diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100755 index 000000000..31898182e --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,11 @@ +#!/bin/bash + +echo "Running pre-push hook..." +if ! make lint; then + echo "" + echo "Hint: To push without verification, run:" + echo " git push --no-verify" + exit 1 +fi + +echo "All checks passed!" diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index d3da427b0..543ba2ab2 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -3,39 +3,108 @@ name: Check License Dependencies on: push: branches: [ main ] + paths: + - 'go.mod' + - 'go.sum' + - '.github/workflows/check-license-dependencies.yml' pull_request: + paths: + - 'go.mod' + - 'go.sum' + - '.github/workflows/check-license-dependencies.yml' jobs: - check-dependencies: + check-internal-dependencies: + name: Check Internal AGPL Dependencies + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Check for problematic license dependencies + run: | + echo "Checking for dependencies on management/, signal/, and relay/ packages..." + echo "" + + # Find all directories except the problematic ones and system dirs + FOUND_ISSUES=0 + while IFS= read -r dir; do + echo "=== Checking $dir ===" + # Search for problematic imports, excluding test files + RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) + if [ -n "$RESULTS" ]; then + echo "❌ Found problematic dependencies:" + echo "$RESULTS" + FOUND_ISSUES=1 + else + echo "✓ No problematic dependencies found" + fi + done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort) + + echo "" + if [ $FOUND_ISSUES -eq 1 ]; then + echo "❌ Found dependencies on management/, signal/, or relay/ packages" + echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" + exit 1 + else + echo "" + echo "✅ All internal license dependencies are clean" + fi + + check-external-licenses: + name: Check External GPL/AGPL Licenses runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Check for problematic license dependencies - run: | - echo "Checking for dependencies on management/, signal/, and relay/ packages..." + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + cache: true - # Find all directories except the problematic ones and system dirs - FOUND_ISSUES=0 - find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do - echo "=== Checking $dir ===" - # Search for problematic imports, excluding test files - RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) - if [ ! -z "$RESULTS" ]; then - echo "❌ Found problematic dependencies:" - echo "$RESULTS" - FOUND_ISSUES=1 - else - echo "✓ No problematic dependencies found" + - name: Install go-licenses + run: go install github.com/google/go-licenses@v1.6.0 + + - name: Check for GPL/AGPL licensed dependencies + run: | + echo "Checking for GPL/AGPL/LGPL licensed dependencies..." + echo "" + + # Check all Go packages for copyleft licenses, excluding internal netbird packages + COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true) + + if [ -n "$COPYLEFT_DEPS" ]; then + echo "Found copyleft licensed dependencies:" + echo "$COPYLEFT_DEPS" + echo "" + + # Filter out dependencies that are only pulled in by internal AGPL packages + INCOMPATIBLE="" + while IFS=',' read -r package url license; do + if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then + # Find ALL packages that import this GPL package using go list + IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") + + # Check if any importer is NOT in management/signal/relay + BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1) + + if [ -n "$BSD_IMPORTER" ]; then + echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" + INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n" + else + echo "✓ $package ($license) is only used by internal AGPL packages - OK" + fi + fi + done <<< "$COPYLEFT_DEPS" + + if [ -n "$INCOMPATIBLE" ]; then + echo "" + echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:" + echo -e "$INCOMPATIBLE" + exit 1 fi - done - if [ $FOUND_ISSUES -eq 1 ]; then - echo "" - echo "❌ Found dependencies on management/, signal/, or relay/ packages" - echo "These packages will change license and should not be imported by client or shared code" - exit 1 - else - echo "" - echo "✅ All license dependencies are clean" fi + + echo "✅ All external license dependencies are compatible with BSD-3-Clause" diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 4571ce753..9c4c35d21 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -15,13 +15,14 @@ jobs: name: "Client / Unit" runs-on: macos-latest steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - name: Cache Go modules uses: actions/cache@v4 diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index cdd0910a4..b03313bbd 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -25,7 +25,7 @@ jobs: release: "14.2" prepare: | pkg install -y curl pkgconf xorg - GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz" + GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" GO_URL="https://go.dev/dl/$GO_TARBALL" curl -vLO "$GO_URL" tar -C /usr/local -vxzf "$GO_TARBALL" diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba36c013b..c09bfab39 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -30,7 +30,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Get Go environment @@ -106,15 +106,15 @@ jobs: arch: [ '386','amd64' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -151,15 +151,15 @@ jobs: needs: [ build-cache ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment id: go-env run: | @@ -200,7 +200,7 @@ jobs: -e GOCACHE=${CONTAINER_GOCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ -e CONTAINER=${CONTAINER} \ - golang:1.23-alpine \ + golang:1.24-alpine \ sh -c ' \ apk update; apk add --no-cache \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ @@ -220,15 +220,15 @@ jobs: raceFlag: "-race" runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 @@ -270,15 +270,15 @@ jobs: arch: [ '386','amd64' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 @@ -321,15 +321,15 @@ jobs: store: [ 'sqlite', 'postgres', 'mysql' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -408,15 +408,16 @@ jobs: -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ -p 9090:9090 \ prom/prometheus - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: "1.23.x" - cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version-file: "go.mod" + cache: false + - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -497,15 +498,15 @@ jobs: -p 9090:9090 \ prom/prometheus + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -561,15 +562,15 @@ jobs: store: [ 'sqlite', 'postgres'] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2083c0721..43357c45f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -24,7 +24,7 @@ jobs: uses: actions/setup-go@v5 id: go with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Get Go environment diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 2845b05a5..c524f6f6b 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -46,7 +46,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index c7d43695b..8325fbf2d 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -20,7 +20,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Setup Android SDK uses: android-actions/setup-android@v3 with: @@ -39,7 +39,7 @@ jobs: - name: Setup NDK run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab - name: gomobile init run: gomobile init - name: build android netbird lib @@ -56,9 +56,9 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab - name: gomobile init run: gomobile init - name: build iOS netbird lib diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e9741f541..2fa847dce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.23" + SIGN_PIPE_VER: "v0.1.0" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -19,8 +19,102 @@ concurrency: cancel-in-progress: true jobs: - release: + release_freebsd_port: + name: "FreeBSD Port / Build & Test" runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Generate FreeBSD port diff + run: bash release_files/freebsd-port-diff.sh + + - name: Generate FreeBSD port issue body + run: bash release_files/freebsd-port-issue-body.sh + + - name: Check if diff was generated + id: check_diff + run: | + if ls netbird-*.diff 1> /dev/null 2>&1; then + echo "diff_exists=true" >> $GITHUB_OUTPUT + else + echo "diff_exists=false" >> $GITHUB_OUTPUT + echo "No diff file generated (port may already be up to date)" + fi + + - name: Extract version + if: steps.check_diff.outputs.diff_exists == 'true' + id: version + run: | + VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/') + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "Generated files for version: $VERSION" + cat netbird-*.diff + + - name: Test FreeBSD port + if: steps.check_diff.outputs.diff_exists == 'true' + uses: vmactions/freebsd-vm@v1 + with: + usesh: true + copyback: false + release: "15.0" + prepare: | + # Install required packages + pkg install -y git curl portlint go + + # Install Go for building + GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" + GO_URL="https://go.dev/dl/$GO_TARBALL" + curl -LO "$GO_URL" + tar -C /usr/local -xzf "$GO_TARBALL" + + # Clone ports tree (shallow, only what we need) + git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports + cd /usr/ports + + run: | + set -e -x + export PATH=$PATH:/usr/local/go/bin + + # Find the diff file + echo "Finding diff file..." + DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1) + echo "Found: $DIFF_FILE" + + if [[ -z "$DIFF_FILE" ]]; then + echo "ERROR: Could not find diff file" + find ~ -name "*.diff" -type f 2>/dev/null || true + exit 1 + fi + + # Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths) + cd /usr/ports + patch -p1 -V none < "$DIFF_FILE" + + # Show patched Makefile + version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}') + + cd /usr/ports/security/netbird + export BATCH=yes + make package + pkg add ./work/pkg/netbird-*.pkg + + netbird version | grep "$version" + + echo "FreeBSD port test completed successfully!" + + - name: Upload FreeBSD port files + if: steps.check_diff.outputs.diff_exists == 'true' + uses: actions/upload-artifact@v4 + with: + name: freebsd-port-files + path: | + ./netbird-*-issue.txt + ./netbird-*.diff + retention-days: 30 + + release: + runs-on: ubuntu-latest-m env: flags: "" steps: @@ -40,7 +134,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 @@ -136,7 +230,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 @@ -200,7 +294,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index 3855baba2..f4513e0e1 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -67,10 +67,13 @@ jobs: - name: Install curl run: sudo apt-get install -y curl + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Cache Go modules uses: actions/cache@v4 @@ -80,9 +83,6 @@ jobs: restore-keys: | ${{ runner.os }}-go- - - name: Checkout code - uses: actions/checkout@v4 - - name: Setup MySQL privileges if: matrix.store == 'mysql' run: | diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml index e4ac799bc..4100e16dd 100644 --- a/.github/workflows/wasm-build-validation.yml +++ b/.github/workflows/wasm-build-validation.yml @@ -20,7 +20,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: Install golangci-lint @@ -45,7 +45,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Build Wasm client run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd env: @@ -60,8 +60,8 @@ jobs: echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" - if [ ${SIZE} -gt 52428800 ]; then - echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + if [ ${SIZE} -gt 57671680 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!" exit 1 fi diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c82cfc763..efc7d9460 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -136,6 +136,14 @@ checked out and set up: go mod tidy ``` +6. Configure Git hooks for automatic linting: + + ```bash + make setup-hooks + ``` + + This will configure Git to run linting automatically before each push, helping catch issues early. + ### Dev Container Support If you prefer using a dev container for development, NetBird now includes support for dev containers. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..43379e115 --- /dev/null +++ b/Makefile @@ -0,0 +1,27 @@ +.PHONY: lint lint-all lint-install setup-hooks +GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint + +# Install golangci-lint locally if needed +$(GOLANGCI_LINT): + @echo "Installing golangci-lint..." + @mkdir -p ./bin + @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +# Lint only changed files (fast, for pre-push) +lint: $(GOLANGCI_LINT) + @echo "Running lint on changed files..." + @$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m + +# Lint entire codebase (slow, matches CI) +lint-all: $(GOLANGCI_LINT) + @echo "Running lint on all files..." + @$(GOLANGCI_LINT) run --timeout=12m + +# Just install the linter +lint-install: $(GOLANGCI_LINT) + +# Setup git hooks for all developers +setup-hooks: + @git config core.hooksPath .githooks + @chmod +x .githooks/pre-push + @echo "✅ Git hooks configured! Pre-push will now run 'make lint'" diff --git a/README.md b/README.md index 2c5ee2ab6..ebf108cdb 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.

- +

See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details. diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ diff --git a/client/android/client.go b/client/android/client.go index d2d0c37f6..ccf32a90c 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,10 +4,13 @@ package android import ( "context" + "fmt" "os" "slices" "sync" + "golang.org/x/exp/maps" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/device" @@ -16,10 +19,13 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile @@ -53,7 +59,6 @@ func init() { // Client struct manage the life circle of background service type Client struct { - cfgFile string tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status @@ -67,12 +72,11 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { execWorkaround(androidSDKVersion) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ - cfgFile: cfgFile, deviceName: deviceName, uiVersion: uiVersion, tunAdapter: tunAdapter, @@ -84,10 +88,16 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) + + cfgFile := platformFiles.ConfigurationFilePath() + stateFile := platformFiles.StateFilePath() + + log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: cfgFile, }) if err != nil { return err @@ -107,23 +117,29 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.login(urlOpener) + err = auth.login(urlOpener, isAndroidTV) if err != nil { return err } // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) + + cfgFile := platformFiles.ConfigurationFilePath() + stateFile := platformFiles.StateFilePath() + + log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: cfgFile, }) if err != nil { return err @@ -141,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } // Stop the internal client and free the resources @@ -156,6 +172,19 @@ func (c *Client) Stop() { c.ctxCancel() } +func (c *Client) RenewTun(fd int) error { + if c.connectClient == nil { + return fmt.Errorf("engine not running") + } + + e := c.connectClient.Engine() + if e == nil { + return fmt.Errorf("engine not initialized") + } + + return e.RenewTun(fd) +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -177,6 +206,7 @@ func (c *Client) PeersList() *PeerInfoArray { p.IP, p.FQDN, p.ConnStatus.String(), + PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi } @@ -201,31 +231,43 @@ func (c *Client) Networks() *NetworkArray { return nil } + routeSelector := routeManager.GetRouteSelector() + if routeSelector == nil { + log.Error("could not get route selector") + return nil + } + networkArray := &NetworkArray{ items: make([]Network, 0), } + resolvedDomains := c.recorder.GetResolvedDomainsStates() + for id, routes := range routeManager.GetClientRoutesWithNetID() { if len(routes) == 0 { continue } r := routes[0] + domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) netStr := r.Network.String() + if r.IsDynamic() { netStr = r.Domains.SafeString() } - peer, err := c.recorder.GetPeer(routes[0].Peer) + routePeer, err := c.recorder.GetPeer(routes[0].Peer) if err != nil { log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) continue } network := Network{ - Name: string(id), - Network: netStr, - Peer: peer.FQDN, - Status: peer.ConnStatus.String(), + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: routeSelector.IsSelected(id), + Domains: domains, } networkArray.Add(network) } @@ -253,6 +295,69 @@ func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } +func (c *Client) toggleRoute(command routeCommand) error { + return command.toggleRoute() +} + +func (c *Client) getRouteManager() (routemanager.Manager, error) { + client := c.connectClient + if client == nil { + return nil, fmt.Errorf("not connected") + } + + engine := client.Engine() + if engine == nil { + return nil, fmt.Errorf("engine is not running") + } + + manager := engine.GetRouteManager() + if manager == nil { + return nil, fmt.Errorf("could not get route manager") + } + + return manager, nil +} + +func (c *Client) SelectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(selectRouteCommand{route: route, manager: manager}) +} + +func (c *Client) DeselectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(deselectRouteCommand{route: route, manager: manager}) +} + +// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain +// with its resolved IP addresses from the provided resolvedDomains map. +func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains { + domains := NetworkDomains{} + + for _, d := range route.Domains { + networkDomain := NetworkDomain{ + Address: d.SafeString(), + } + + if info, exists := resolvedDomains[d]; exists { + for _, prefix := range info.Prefixes { + networkDomain.addResolvedIP(prefix.Addr().String()) + } + } + + domains.Add(&networkDomain) + } + + return domains +} + func exportEnvList(list *EnvList) { if list == nil { return diff --git a/client/android/login.go b/client/android/login.go index 0df78dbc3..4d4c7a650 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -32,7 +32,7 @@ type ErrListener interface { // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { - Open(string) + Open(url string, userCode string) OnLoginSuccess() } @@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string } // Login try register the client on the server -func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) { go func() { - err := a.login(urlOpener) + err := a.login(urlOpener, isAndroidTV) if err != nil { resultListener.OnError(err) } else { @@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { }() } -func (a *Auth) login(urlOpener URLOpener) error { +func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { var needsLogin bool // check if we need to generate JWT token @@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error { jwtToken := "" if needsLogin { - tokenInfo, err := a.foregroundGetTokenInfo(urlOpener) + tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error { return nil } -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) +func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "") if err != nil { return nil, err } @@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) } - go urlOpener.Open(flowInfo.VerificationURIComplete) + go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) diff --git a/client/android/network_domains.go b/client/android/network_domains.go new file mode 100644 index 000000000..a459bdc23 --- /dev/null +++ b/client/android/network_domains.go @@ -0,0 +1,56 @@ +//go:build android + +package android + +import "fmt" + +type ResolvedIPs struct { + resolvedIPs []string +} + +func (r *ResolvedIPs) Add(ipAddress string) { + r.resolvedIPs = append(r.resolvedIPs, ipAddress) +} + +func (r *ResolvedIPs) Get(i int) (string, error) { + if i < 0 || i >= len(r.resolvedIPs) { + return "", fmt.Errorf("%d is out of range", i) + } + return r.resolvedIPs[i], nil +} + +func (r *ResolvedIPs) Size() int { + return len(r.resolvedIPs) +} + +type NetworkDomain struct { + Address string + resolvedIPs ResolvedIPs +} + +func (d *NetworkDomain) addResolvedIP(resolvedIP string) { + d.resolvedIPs.Add(resolvedIP) +} + +func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs { + return &d.resolvedIPs +} + +type NetworkDomains struct { + domains []*NetworkDomain +} + +func (n *NetworkDomains) Add(domain *NetworkDomain) { + n.domains = append(n.domains, domain) +} + +func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) { + if i < 0 || i >= len(n.domains) { + return nil, fmt.Errorf("%d is out of range", i) + } + return n.domains[i], nil +} + +func (n *NetworkDomains) Size() int { + return len(n.domains) +} diff --git a/client/android/networks.go b/client/android/networks.go index aa130420b..3c3a25939 100644 --- a/client/android/networks.go +++ b/client/android/networks.go @@ -3,10 +3,16 @@ package android type Network struct { - Name string - Network string - Peer string - Status string + Name string + Network string + Peer string + Status string + IsSelected bool + Domains NetworkDomains +} + +func (n Network) GetNetworkDomains() *NetworkDomains { + return &n.Domains } type NetworkArray struct { diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 1f5564c72..b03947da1 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build android + package android // PeerInfo describe information about the peers. It designed for the UI usage @@ -5,6 +7,11 @@ type PeerInfo struct { IP string FQDN string ConnStatus string // Todo replace to enum + Routes PeerRoutes +} + +func (p *PeerInfo) GetPeerRoutes() *PeerRoutes { + return &p.Routes } // PeerInfoArray is a wrapper of []PeerInfo diff --git a/client/android/peer_routes.go b/client/android/peer_routes.go new file mode 100644 index 000000000..bb46d609f --- /dev/null +++ b/client/android/peer_routes.go @@ -0,0 +1,20 @@ +//go:build android + +package android + +import "fmt" + +type PeerRoutes struct { + routes []string +} + +func (p *PeerRoutes) Get(i int) (string, error) { + if i < 0 || i >= len(p.routes) { + return "", fmt.Errorf("%d is out of range", i) + } + return p.routes[i], nil +} + +func (p *PeerRoutes) Size() int { + return len(p.routes) +} diff --git a/client/android/platform_files.go b/client/android/platform_files.go new file mode 100644 index 000000000..f0c369750 --- /dev/null +++ b/client/android/platform_files.go @@ -0,0 +1,10 @@ +//go:build android + +package android + +// PlatformFiles groups paths to files used internally by the engine that can't be created/modified +// at their default locations due to android OS restrictions. +type PlatformFiles interface { + ConfigurationFilePath() string + StateFilePath() string +} diff --git a/client/android/preferences.go b/client/android/preferences.go index 9a5d6bb21..c3c8eb3fb 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) { p.configInput.ServerSSHAllowed = &allowed } +// GetEnableSSHRoot reads SSH root login setting from config file +func (p *Preferences) GetEnableSSHRoot() (bool, error) { + if p.configInput.EnableSSHRoot != nil { + return *p.configInput.EnableSSHRoot, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHRoot == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHRoot, err +} + +// SetEnableSSHRoot stores the given value and waits for commit +func (p *Preferences) SetEnableSSHRoot(enabled bool) { + p.configInput.EnableSSHRoot = &enabled +} + +// GetEnableSSHSFTP reads SSH SFTP setting from config file +func (p *Preferences) GetEnableSSHSFTP() (bool, error) { + if p.configInput.EnableSSHSFTP != nil { + return *p.configInput.EnableSSHSFTP, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHSFTP == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHSFTP, err +} + +// SetEnableSSHSFTP stores the given value and waits for commit +func (p *Preferences) SetEnableSSHSFTP(enabled bool) { + p.configInput.EnableSSHSFTP = &enabled +} + +// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file +func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) { + if p.configInput.EnableSSHLocalPortForwarding != nil { + return *p.configInput.EnableSSHLocalPortForwarding, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHLocalPortForwarding == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHLocalPortForwarding, err +} + +// SetEnableSSHLocalPortForwarding stores the given value and waits for commit +func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) { + p.configInput.EnableSSHLocalPortForwarding = &enabled +} + +// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file +func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) { + if p.configInput.EnableSSHRemotePortForwarding != nil { + return *p.configInput.EnableSSHRemotePortForwarding, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHRemotePortForwarding == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHRemotePortForwarding, err +} + +// SetEnableSSHRemotePortForwarding stores the given value and waits for commit +func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) { + p.configInput.EnableSSHRemotePortForwarding = &enabled +} + // GetBlockInbound reads block inbound setting from config file func (p *Preferences) GetBlockInbound() (bool, error) { if p.configInput.BlockInbound != nil { diff --git a/client/android/profile_manager.go b/client/android/profile_manager.go new file mode 100644 index 000000000..60e4d5c32 --- /dev/null +++ b/client/android/profile_manager.go @@ -0,0 +1,257 @@ +//go:build android + +package android + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/profilemanager" +) + +const ( + // Android-specific config filename (different from desktop default.json) + defaultConfigFilename = "netbird.cfg" + // Subdirectory for non-default profiles (must match Java Preferences.java) + profilesSubdir = "profiles" + // Android uses a single user context per app (non-empty username required by ServiceManager) + androidUsername = "android" +) + +// Profile represents a profile for gomobile +type Profile struct { + Name string + IsActive bool +} + +// ProfileArray wraps profiles for gomobile compatibility +type ProfileArray struct { + items []*Profile +} + +// Length returns the number of profiles +func (p *ProfileArray) Length() int { + return len(p.items) +} + +// Get returns the profile at index i +func (p *ProfileArray) Get(i int) *Profile { + if i < 0 || i >= len(p.items) { + return nil + } + return p.items[i] +} + +/* + +/data/data/io.netbird.client/files/ ← configDir parameter +├── netbird.cfg ← Default profile config +├── state.json ← Default profile state +├── active_profile.json ← Active profile tracker (JSON with Name + Username) +└── profiles/ ← Subdirectory for non-default profiles + ├── work.json ← Work profile config + ├── work.state.json ← Work profile state + ├── personal.json ← Personal profile config + └── personal.state.json ← Personal profile state +*/ + +// ProfileManager manages profiles for Android +// It wraps the internal profilemanager to provide Android-specific behavior +type ProfileManager struct { + configDir string + serviceMgr *profilemanager.ServiceManager +} + +// NewProfileManager creates a new profile manager for Android +func NewProfileManager(configDir string) *ProfileManager { + // Set the default config path for Android (stored in root configDir, not profiles/) + defaultConfigPath := filepath.Join(configDir, defaultConfigFilename) + + // Set global paths for Android + profilemanager.DefaultConfigPathDir = configDir + profilemanager.DefaultConfigPath = defaultConfigPath + profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json") + + // Create ServiceManager with profiles/ subdirectory + // This avoids modifying the global ConfigDirOverride for profile listing + profilesDir := filepath.Join(configDir, profilesSubdir) + serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir) + + return &ProfileManager{ + configDir: configDir, + serviceMgr: serviceMgr, + } +} + +// ListProfiles returns all available profiles +func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) { + // Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive) + internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername) + if err != nil { + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + // Convert internal profiles to Android Profile type + var profiles []*Profile + for _, p := range internalProfiles { + profiles = append(profiles, &Profile{ + Name: p.Name, + IsActive: p.IsActive, + }) + } + + return &ProfileArray{items: profiles}, nil +} + +// GetActiveProfile returns the currently active profile name +func (pm *ProfileManager) GetActiveProfile() (string, error) { + // Use ServiceManager to stay consistent with ListProfiles + // ServiceManager uses active_profile.json + activeState, err := pm.serviceMgr.GetActiveProfileState() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return activeState.Name, nil +} + +// SwitchProfile switches to a different profile +func (pm *ProfileManager) SwitchProfile(profileName string) error { + // Use ServiceManager to stay consistent with ListProfiles + // ServiceManager uses active_profile.json + err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profileName, + Username: androidUsername, + }) + if err != nil { + return fmt.Errorf("failed to switch profile: %w", err) + } + + log.Infof("switched to profile: %s", profileName) + return nil +} + +// AddProfile creates a new profile +func (pm *ProfileManager) AddProfile(profileName string) error { + // Use ServiceManager (creates profile in profiles/ directory) + if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil { + return fmt.Errorf("failed to add profile: %w", err) + } + + log.Infof("created new profile: %s", profileName) + return nil +} + +// LogoutProfile logs out from a profile (clears authentication) +func (pm *ProfileManager) LogoutProfile(profileName string) error { + profileName = sanitizeProfileName(profileName) + + configPath, err := pm.getProfileConfigPath(profileName) + if err != nil { + return err + } + + // Check if profile exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return fmt.Errorf("profile '%s' does not exist", profileName) + } + + // Read current config using internal profilemanager + config, err := profilemanager.ReadConfig(configPath) + if err != nil { + return fmt.Errorf("failed to read profile config: %w", err) + } + + // Clear authentication by removing private key and SSH key + config.PrivateKey = "" + config.SSHKey = "" + + // Save config using internal profilemanager + if err := profilemanager.WriteOutConfig(configPath, config); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + log.Infof("logged out from profile: %s", profileName) + return nil +} + +// RemoveProfile deletes a profile +func (pm *ProfileManager) RemoveProfile(profileName string) error { + // Use ServiceManager (removes profile from profiles/ directory) + if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil { + return fmt.Errorf("failed to remove profile: %w", err) + } + + log.Infof("removed profile: %s", profileName) + return nil +} + +// getProfileConfigPath returns the config file path for a profile +// This is needed for Android-specific path handling (netbird.cfg for default profile) +func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) { + if profileName == "" || profileName == profilemanager.DefaultProfileName { + // Android uses netbird.cfg for default profile instead of default.json + // Default profile is stored in root configDir, not in profiles/ + return filepath.Join(pm.configDir, defaultConfigFilename), nil + } + + // Non-default profiles are stored in profiles subdirectory + // This matches the Java Preferences.java expectation + profileName = sanitizeProfileName(profileName) + profilesDir := filepath.Join(pm.configDir, profilesSubdir) + return filepath.Join(profilesDir, profileName+".json"), nil +} + +// GetConfigPath returns the config file path for a given profile +// Java should call this instead of constructing paths with Preferences.configFile() +func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) { + return pm.getProfileConfigPath(profileName) +} + +// GetStateFilePath returns the state file path for a given profile +// Java should call this instead of constructing paths with Preferences.stateFile() +func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) { + if profileName == "" || profileName == profilemanager.DefaultProfileName { + return filepath.Join(pm.configDir, "state.json"), nil + } + + profileName = sanitizeProfileName(profileName) + profilesDir := filepath.Join(pm.configDir, profilesSubdir) + return filepath.Join(profilesDir, profileName+".state.json"), nil +} + +// GetActiveConfigPath returns the config file path for the currently active profile +// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile() +func (pm *ProfileManager) GetActiveConfigPath() (string, error) { + activeProfile, err := pm.GetActiveProfile() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return pm.GetConfigPath(activeProfile) +} + +// GetActiveStateFilePath returns the state file path for the currently active profile +// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile() +func (pm *ProfileManager) GetActiveStateFilePath() (string, error) { + activeProfile, err := pm.GetActiveProfile() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return pm.GetStateFilePath(activeProfile) +} + +// sanitizeProfileName removes invalid characters from profile name +func sanitizeProfileName(name string) string { + // Keep only alphanumeric, underscore, and hyphen + var result strings.Builder + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '_' || r == '-' { + result.WriteRune(r) + } + } + return result.String() +} diff --git a/client/android/route_command.go b/client/android/route_command.go new file mode 100644 index 000000000..b47d5ca6c --- /dev/null +++ b/client/android/route_command.go @@ -0,0 +1,67 @@ +//go:build android + +package android + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/route" +) + +func executeRouteToggle(id string, manager routemanager.Manager, + operationName string, + routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error { + netID := route.NetID(id) + routes := []route.NetID{netID} + + log.Debugf("%s with id: %s", operationName, id) + + if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("error when %s: %s", operationName, err) + return fmt.Errorf("error %s: %w", operationName, err) + } + + manager.TriggerSelection(manager.GetClientRoutes()) + + return nil +} + +type routeCommand interface { + toggleRoute() error +} + +type selectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (s selectRouteCommand) toggleRoute() error { + routeSelector := s.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error { + return routeSelector.SelectRoutes(routes, true, allRoutes) + } + + return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation) +} + +type deselectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (d deselectRouteCommand) toggleRoute() error { + routeSelector := d.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes) +} diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/cmd/login.go b/client/cmd/login.go index 675b5fdf8..1ddddd3f1 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -10,7 +10,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -105,6 +104,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str Username: &username, } + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { loginRequest.OptionalPreSharedKey = &preSharedKey } @@ -258,7 +264,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, return fmt.Errorf("read config file %s: %v", configFilePath, err) } - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -286,7 +292,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo return nil } -func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { +func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { needsLogin := false err := WithBackOff(func() error { @@ -303,7 +309,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken := "" if setupKey == "" && needsLogin { - tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) + tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -332,8 +338,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman return nil } -func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) +func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { + hint := "" + pm := profilemanager.NewProfileManager() + profileState, err := pm.GetProfileState(profileName) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + hint = profileState.Email + } + + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint) if err != nil { return nil, err } @@ -374,7 +389,7 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := util.OpenBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } diff --git a/client/cmd/root.go b/client/cmd/root.go index 11e5228f1..30120c196 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -35,7 +35,6 @@ const ( wireguardPortFlag = "wireguard-port" networkMonitorFlag = "network-monitor" disableAutoConnectFlag = "disable-auto-connect" - serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" enableLazyConnectionFlag = "enable-lazy-connection" @@ -64,7 +63,6 @@ var ( customDNSAddress string rosenpassEnabled bool rosenpassPermissive bool - serverSSHAllowed bool interfaceName string wireguardPort uint16 networkMonitor bool @@ -87,6 +85,9 @@ var ( // Execute executes the root command. func Execute() error { + if isUpdateBinary() { + return updateCmd.Execute() + } return rootCmd.Execute() } @@ -176,7 +177,6 @@ func init() { ) upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") - upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 075ead44e..f6828d96a 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -10,6 +10,8 @@ import ( "path/filepath" "runtime" + log "github.com/sirupsen/logrus" + "github.com/kardianos/service" "github.com/spf13/cobra" @@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error { svcConfig.Option["LogDirectory"] = dir } } + + if err := configureSystemdNetworkd(); err != nil { + log.Warnf("failed to configure systemd-networkd: %v", err) + } } if runtime.GOOS == "windows" { @@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{ return fmt.Errorf("uninstall service: %w", err) } + if runtime.GOOS == "linux" { + if err := cleanupSystemdNetworkd(); err != nil { + log.Warnf("failed to cleanup systemd-networkd configuration: %v", err) + } + } + cmd.Println("NetBird service has been uninstalled") return nil }, @@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) { return status == service.StatusRunning, nil } + +const ( + networkdConf = "/etc/systemd/networkd.conf" + networkdConfDir = "/etc/systemd/networkd.conf.d" + networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" + networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing +# routes and policy rules managed by NetBird. + +[Network] +ManageForeignRoutes=no +ManageForeignRoutingPolicyRules=no +` +) + +// configureSystemdNetworkd creates a drop-in configuration file to prevent +// systemd-networkd from removing NetBird's routes and policy rules. +func configureSystemdNetworkd() error { + if _, err := os.Stat(networkdConf); os.IsNotExist(err) { + log.Debug("systemd-networkd not in use, skipping configuration") + return nil + } + + // nolint:gosec // standard networkd permissions + if err := os.MkdirAll(networkdConfDir, 0755); err != nil { + return fmt.Errorf("create networkd.conf.d directory: %w", err) + } + + // nolint:gosec // standard networkd permissions + if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { + return fmt.Errorf("write networkd configuration: %w", err) + } + + return nil +} + +// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file. +func cleanupSystemdNetworkd() error { + if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) { + return nil + } + + if err := os.Remove(networkdConfFile); err != nil { + return fmt.Errorf("remove networkd configuration: %w", err) + } + + return nil +} diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go new file mode 100644 index 000000000..5e656650b --- /dev/null +++ b/client/cmd/signer/artifactkey.go @@ -0,0 +1,176 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +var ( + bundlePubKeysRootPrivKeyFile string + bundlePubKeysPubKeyFiles []string + bundlePubKeysFile string + + createArtifactKeyRootPrivKeyFile string + createArtifactKeyPrivKeyFile string + createArtifactKeyPubKeyFile string + createArtifactKeyExpiration time.Duration +) + +var createArtifactKeyCmd = &cobra.Command{ + Use: "create-artifact-key", + Short: "Create a new artifact signing key", + Long: `Generate a new artifact signing key pair signed by the root private key. +The artifact key will be used to sign software artifacts/updates.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if createArtifactKeyExpiration <= 0 { + return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)") + } + + if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil { + return fmt.Errorf("failed to create artifact key: %w", err) + } + return nil + }, +} + +var bundlePubKeysCmd = &cobra.Command{ + Use: "bundle-pub-keys", + Short: "Bundle multiple artifact public keys into a signed package", + Long: `Bundle one or more artifact public keys into a signed package using the root private key. +This command is typically used to distribute or authorize a set of valid artifact signing keys.`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(bundlePubKeysPubKeyFiles) == 0 { + return fmt.Errorf("at least one --artifact-pub-key-file must be provided") + } + + if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil { + return fmt.Errorf("failed to bundle public keys: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(createArtifactKeyCmd) + + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key") + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved") + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved") + createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)") + + if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil { + panic(fmt.Errorf("mark root-private-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil { + panic(fmt.Errorf("mark expiration as required: %w", err)) + } + + rootCmd.AddCommand(bundlePubKeysCmd) + + bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle") + bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)") + bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved") + + if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil { + panic(fmt.Errorf("mark root-private-key-file as required: %w", err)) + } + if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err)) + } + if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil { + panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err)) + } +} + +func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error { + cmd.Println("Creating new artifact signing key...") + + privKeyPEM, err := os.ReadFile(rootPrivKeyFile) + if err != nil { + return fmt.Errorf("read root private key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration) + if err != nil { + return fmt.Errorf("generate artifact key: %w", err) + } + + if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil { + return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err) + } + + if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil { + return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err) + } + + signatureFile := artifactPubKeyFile + ".sig" + if err := os.WriteFile(signatureFile, signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", signatureFile, err) + } + + cmd.Printf("✅ Artifact key created successfully.\n") + cmd.Printf("%s\n", artifactKey.String()) + return nil +} + +func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error { + cmd.Println("📦 Bundling public keys into signed package...") + + privKeyPEM, err := os.ReadFile(rootPrivKeyFile) + if err != nil { + return fmt.Errorf("read root private key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles)) + for _, pubFile := range artifactPubKeyFiles { + pubPem, err := os.ReadFile(pubFile) + if err != nil { + return fmt.Errorf("read public key file: %w", err) + } + + pk, err := reposign.ParseArtifactPubKey(pubPem) + if err != nil { + return fmt.Errorf("failed to parse artifact key: %w", err) + } + publicKeys = append(publicKeys, pk) + } + + parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys) + if err != nil { + return fmt.Errorf("bundle artifact keys: %w", err) + } + + if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil { + return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err) + } + + signatureFile := bundlePubKeysFile + ".sig" + if err := os.WriteFile(signatureFile, signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", signatureFile, err) + } + + cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles)) + return nil +} diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go new file mode 100644 index 000000000..881be9367 --- /dev/null +++ b/client/cmd/signer/artifactsign.go @@ -0,0 +1,276 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +const ( + envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY" +) + +var ( + signArtifactPrivKeyFile string + signArtifactArtifactFile string + + verifyArtifactPubKeyFile string + verifyArtifactFile string + verifyArtifactSignatureFile string + + verifyArtifactKeyPubKeyFile string + verifyArtifactKeyRootPubKeyFile string + verifyArtifactKeySignatureFile string + verifyArtifactKeyRevocationFile string +) + +var signArtifactCmd = &cobra.Command{ + Use: "sign-artifact", + Short: "Sign an artifact using an artifact private key", + Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key. +This command produces a detached signature that can be verified using the corresponding artifact public key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil { + return fmt.Errorf("failed to sign artifact: %w", err) + } + return nil + }, +} + +var verifyArtifactCmd = &cobra.Command{ + Use: "verify-artifact", + Short: "Verify an artifact signature using an artifact public key", + Long: `Verify a software artifact signature using the artifact's public key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil { + return fmt.Errorf("failed to verify artifact: %w", err) + } + return nil + }, +} + +var verifyArtifactKeyCmd = &cobra.Command{ + Use: "verify-artifact-key", + Short: "Verify an artifact public key was signed by a root key", + Long: `Verify that an artifact public key (or bundle) was properly signed by a root key. +This validates the chain of trust from the root key to the artifact key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil { + return fmt.Errorf("failed to verify artifact key: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(signArtifactCmd) + rootCmd.AddCommand(verifyArtifactCmd) + rootCmd.AddCommand(verifyArtifactKeyCmd) + + signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey)) + signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed") + + // artifact-file is required, but artifact-key-file can come from env var + if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil { + panic(fmt.Errorf("mark artifact-file as required: %w", err)) + } + + verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file") + verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified") + verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file") + + if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err)) + } + if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil { + panic(fmt.Errorf("mark artifact-file as required: %w", err)) + } + if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil { + panic(fmt.Errorf("mark signature-file as required: %w", err)) + } + + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)") + + if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-key-file as required: %w", err)) + } + if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil { + panic(fmt.Errorf("mark root-key-file as required: %w", err)) + } + if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil { + panic(fmt.Errorf("mark signature-file as required: %w", err)) + } +} + +func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error { + cmd.Println("🖋️ Signing artifact...") + + // Load private key from env var or file + var privKeyPEM []byte + var err error + + if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" { + // Use key from environment variable + privKeyPEM = []byte(envKey) + } else if privKeyFile != "" { + // Fall back to file + privKeyPEM, err = os.ReadFile(privKeyFile) + if err != nil { + return fmt.Errorf("read private key file: %w", err) + } + } else { + return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey) + } + + privateKey, err := reposign.ParseArtifactKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse artifact private key: %w", err) + } + + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + return fmt.Errorf("read artifact file: %w", err) + } + + signature, err := reposign.SignData(privateKey, artifactData) + if err != nil { + return fmt.Errorf("sign artifact: %w", err) + } + + sigFile := artifactFile + ".sig" + if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", sigFile, err) + } + + cmd.Printf("✅ Artifact signed successfully.\n") + cmd.Printf("Signature file: %s\n", sigFile) + return nil +} + +func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error { + cmd.Println("🔍 Verifying artifact...") + + // Read artifact public key + pubKeyPEM, err := os.ReadFile(pubKeyFile) + if err != nil { + return fmt.Errorf("read public key file: %w", err) + } + + publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse artifact public key: %w", err) + } + + // Read artifact data + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + return fmt.Errorf("read artifact file: %w", err) + } + + // Read signature + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("read signature file: %w", err) + } + + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Validate artifact + if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil { + return fmt.Errorf("artifact verification failed: %w", err) + } + + cmd.Println("✅ Artifact signature is valid") + cmd.Printf("Artifact: %s\n", artifactFile) + cmd.Printf("Signed by key: %s\n", signature.KeyID) + cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST")) + return nil +} + +func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error { + cmd.Println("🔍 Verifying artifact key...") + + // Read artifact key data + artifactKeyData, err := os.ReadFile(artifactKeyFile) + if err != nil { + return fmt.Errorf("read artifact key file: %w", err) + } + + // Read root public key(s) + rootKeyData, err := os.ReadFile(rootKeyFile) + if err != nil { + return fmt.Errorf("read root key file: %w", err) + } + + rootPublicKeys, err := parseRootPublicKeys(rootKeyData) + if err != nil { + return fmt.Errorf("failed to parse root public key(s): %w", err) + } + + // Read signature + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("read signature file: %w", err) + } + + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Read optional revocation list + var revocationList *reposign.RevocationList + if revocationFile != "" { + revData, err := os.ReadFile(revocationFile) + if err != nil { + return fmt.Errorf("read revocation file: %w", err) + } + + revocationList, err = reposign.ParseRevocationList(revData) + if err != nil { + return fmt.Errorf("failed to parse revocation list: %w", err) + } + } + + // Validate artifact key(s) + validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList) + if err != nil { + return fmt.Errorf("artifact key verification failed: %w", err) + } + + cmd.Println("✅ Artifact key(s) verified successfully") + cmd.Printf("Signed by root key: %s\n", signature.KeyID) + cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST")) + cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys)) + for i, key := range validKeys { + cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID) + cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST")) + if !key.Metadata.ExpiresAt.IsZero() { + cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST")) + } else { + cmd.Printf(" Expires: Never\n") + } + } + return nil +} + +// parseRootPublicKeys parses a root public key from PEM data +func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) { + key, err := reposign.ParseRootPublicKey(data) + if err != nil { + return nil, err + } + return []reposign.PublicKey{key}, nil +} diff --git a/client/cmd/signer/main.go b/client/cmd/signer/main.go new file mode 100644 index 000000000..407093d07 --- /dev/null +++ b/client/cmd/signer/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "signer", + Short: "A CLI tool for managing cryptographic keys and artifacts", + Long: `signer is a command-line tool that helps you manage +root keys, artifact keys, and revocation lists securely.`, +} + +func main() { + if err := rootCmd.Execute(); err != nil { + rootCmd.Println(err) + os.Exit(1) + } +} diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go new file mode 100644 index 000000000..1d84b65c3 --- /dev/null +++ b/client/cmd/signer/revocation.go @@ -0,0 +1,220 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +const ( + defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year +) + +var ( + keyID string + revocationListFile string + privateRootKeyFile string + publicRootKeyFile string + signatureFile string + expirationDuration time.Duration +) + +var createRevocationListCmd = &cobra.Command{ + Use: "create-revocation-list", + Short: "Create a new revocation list signed by the private root key", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile) + }, +} + +var extendRevocationListCmd = &cobra.Command{ + Use: "extend-revocation-list", + Short: "Extend an existing revocation list with a given key ID", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile) + }, +} + +var verifyRevocationListCmd = &cobra.Command{ + Use: "verify-revocation-list", + Short: "Verify a revocation list signature using the public root key", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile) + }, +} + +func init() { + rootCmd.AddCommand(createRevocationListCmd) + rootCmd.AddCommand(extendRevocationListCmd) + rootCmd.AddCommand(verifyRevocationListCmd) + + createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file") + createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file") + createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)") + if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil { + panic(err) + } + + extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for") + extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file") + extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file") + extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)") + if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil { + panic(err) + } + if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil { + panic(err) + } + + verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file") + verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file") + verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file") + if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil { + panic(err) + } + if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil { + panic(err) + } +} + +func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error { + privKeyPEM, err := os.ReadFile(privateRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read private root key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration) + if err != nil { + return fmt.Errorf("failed to create revocation list: %w", err) + } + + if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil { + return fmt.Errorf("failed to write output files: %w", err) + } + + cmd.Println("✅ Revocation list created successfully") + return nil +} + +func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error { + privKeyPEM, err := os.ReadFile(privateRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read private root key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + rlBytes, err := os.ReadFile(revocationListFile) + if err != nil { + return fmt.Errorf("failed to read revocation list file: %w", err) + } + + rl, err := reposign.ParseRevocationList(rlBytes) + if err != nil { + return fmt.Errorf("failed to parse revocation list: %w", err) + } + + kid, err := reposign.ParseKeyID(keyID) + if err != nil { + return fmt.Errorf("invalid key ID: %w", err) + } + + newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration) + if err != nil { + return fmt.Errorf("failed to extend revocation list: %w", err) + } + + if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil { + return fmt.Errorf("failed to write output files: %w", err) + } + + cmd.Println("✅ Revocation list extended successfully") + return nil +} + +func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error { + // Read revocation list file + rlBytes, err := os.ReadFile(revocationListFile) + if err != nil { + return fmt.Errorf("failed to read revocation list file: %w", err) + } + + // Read signature file + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("failed to read signature file: %w", err) + } + + // Read public root key file + pubKeyPEM, err := os.ReadFile(publicRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read public root key file: %w", err) + } + + // Parse public root key + publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse public root key: %w", err) + } + + // Parse signature + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Validate revocation list + rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature) + if err != nil { + return fmt.Errorf("failed to validate revocation list: %w", err) + } + + // Display results + cmd.Println("✅ Revocation list signature is valid") + cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339)) + cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339)) + cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked)) + + if len(rl.Revoked) > 0 { + cmd.Println("\nRevoked Keys:") + for keyID, revokedTime := range rl.Revoked { + cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339)) + } + } + + return nil +} + +func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error { + if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil { + return fmt.Errorf("failed to write revocation list file: %w", err) + } + if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil { + return fmt.Errorf("failed to write signature file: %w", err) + } + return nil +} diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go new file mode 100644 index 000000000..78ac36b41 --- /dev/null +++ b/client/cmd/signer/rootkey.go @@ -0,0 +1,74 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +var ( + privKeyFile string + pubKeyFile string + rootExpiration time.Duration +) + +var createRootKeyCmd = &cobra.Command{ + Use: "create-root-key", + Short: "Create a new root key pair", + Long: `Create a new root key pair and specify an expiration time for it.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + // Validate expiration + if rootExpiration <= 0 { + return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)") + } + + // Run main logic + if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil { + return fmt.Errorf("failed to generate root key: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(createRootKeyCmd) + createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file") + createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file") + createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)") + + if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil { + panic(err) + } + if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil { + panic(err) + } + if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil { + panic(err) + } +} + +func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error { + rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration) + if err != nil { + return fmt.Errorf("generate root key: %w", err) + } + + // Write private key + if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil { + return fmt.Errorf("write private key file (%s): %w", privKeyFile, err) + } + + // Write public key + if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil { + return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err) + } + + cmd.Printf("%s\n\n", rk.String()) + cmd.Printf("✅ Root key pair generated successfully.\n") + return nil +} diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 5358ddacb..525bcdef1 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -3,125 +3,849 @@ package cmd import ( "context" "errors" + "flag" "fmt" + "net" "os" "os/signal" + "os/user" + "slices" + "strconv" "strings" "syscall" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/client/internal/profilemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" + sshclient "github.com/netbirdio/netbird/client/ssh/client" + "github.com/netbirdio/netbird/client/ssh/detection" + sshproxy "github.com/netbirdio/netbird/client/ssh/proxy" + sshserver "github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/util" ) -var ( - port int - userName = "root" - host string +const ( + sshUsernameDesc = "SSH username" + hostArgumentRequired = "host argument required" + + serverSSHAllowedFlag = "allow-server-ssh" + enableSSHRootFlag = "enable-ssh-root" + enableSSHSFTPFlag = "enable-ssh-sftp" + enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding" + enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding" + disableSSHAuthFlag = "disable-ssh-auth" + sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl" ) -var sshCmd = &cobra.Command{ - Use: "ssh [user@]host", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return errors.New("requires a host argument") - } +var ( + port int + username string + host string + command string + localForwards []string + remoteForwards []string + strictHostKeyChecking bool + knownHostsFile string + identityFile string + skipCachedToken bool + requestPTY bool + sshNoBrowser bool +) - split := strings.Split(args[0], "@") - if len(split) == 2 { - userName = split[0] - host = split[1] - } else { - host = args[0] - } +var ( + serverSSHAllowed bool + enableSSHRoot bool + enableSSHSFTP bool + enableSSHLocalPortForward bool + enableSSHRemotePortForward bool + disableSSHAuth bool + sshJWTCacheTTL int +) - return nil - }, - Short: "Connect to a remote SSH server", - RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - SetFlagsFromEnvVars(cmd) +func init() { + upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer") + upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server") + upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication") + upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)") - cmd.SetOut(cmd.OutOrStdout()) + sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port") + sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc) + sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)") + sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation") + sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)") + sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)") + sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") + _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") + sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc) - err := util.InitLog(logLevel, util.LogConsole) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } + sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") + sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") - if !util.IsAdmin() { - cmd.Printf("error: you must have Administrator privileges to run this command\n") - return nil - } - - ctx := internal.CtxInitState(cmd.Context()) - - sm := profilemanager.NewServiceManager(configPath) - activeProf, err := sm.GetActiveProfileState() - if err != nil { - return fmt.Errorf("get active profile: %v", err) - } - profPath, err := activeProf.FilePath() - if err != nil { - return fmt.Errorf("get active profile path: %v", err) - } - - config, err := profilemanager.ReadConfig(profPath) - if err != nil { - return fmt.Errorf("read profile config: %v", err) - } - - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) - sshctx, cancel := context.WithCancel(ctx) - - go func() { - // blocking - if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { - cmd.Printf("Error: %v\n", err) - os.Exit(1) - } - cancel() - }() - - select { - case <-sig: - cancel() - case <-sshctx.Done(): - } - - return nil - }, + sshCmd.AddCommand(sshSftpCmd) + sshCmd.AddCommand(sshProxyCmd) + sshCmd.AddCommand(sshDetectCmd) } -func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error { - c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey) - if err != nil { - cmd.Printf("Error: %v\n", err) - cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + - "\nYou can verify the connection by running:\n\n" + - " netbird status\n\n") - return err - } - go func() { - <-ctx.Done() - err = c.Close() - if err != nil { - return +var sshCmd = &cobra.Command{ + Use: "ssh [flags] [user@]host [command]", + Short: "Connect to a NetBird peer via SSH", + Long: `Connect to a NetBird peer using SSH with support for port forwarding. + +Port Forwarding: + -L [bind_address:]port:host:hostport Local port forwarding + -L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket + -R [bind_address:]port:host:hostport Remote port forwarding + -R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket + +SSH Options: + -p, --port int Remote SSH port (default 22) + -u, --user string SSH username + --login string SSH username (alias for --user) + -t, --tty Force pseudo-terminal allocation + --strict-host-key-checking Enable strict host key checking (default: true) + -o, --known-hosts string Path to known_hosts file + +Examples: + netbird ssh peer-hostname + netbird ssh root@peer-hostname + netbird ssh --login root peer-hostname + netbird ssh peer-hostname ls -la + netbird ssh peer-hostname whoami + netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen + netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo + netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding + netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding + netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces + netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`, + DisableFlagParsing: true, + Args: validateSSHArgsWithoutFlagParsing, + RunE: sshFn, + Aliases: []string{"ssh"}, +} + +func sshFn(cmd *cobra.Command, args []string) error { + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return cmd.Help() } + } + + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(cmd) + + cmd.SetOut(cmd.OutOrStdout()) + + logOutput := "console" + if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { + logOutput = firstLogFile + } + if err := util.InitLog(logLevel, logOutput); err != nil { + return fmt.Errorf("init log: %w", err) + } + + ctx := internal.CtxInitState(cmd.Context()) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) + sshctx, cancel := context.WithCancel(ctx) + + errCh := make(chan error, 1) + go func() { + if err := runSSH(sshctx, host, cmd); err != nil { + errCh <- err + } + cancel() }() - err = c.OpenTerminal() - if err != nil { + select { + case <-sig: + cancel() + <-sshctx.Done() + return nil + case err := <-errCh: return err + case <-sshctx.Done(): } return nil } -func init() { - sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort)) +// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes +func getEnvOrDefault(flagName, defaultValue string) string { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + return envValue + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + return envValue + } + return defaultValue +} + +// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes +func getBoolEnvOrDefault(flagName string, defaultValue bool) bool { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + return defaultValue +} + +// resetSSHGlobals sets SSH globals to their default values +func resetSSHGlobals() { + port = sshserver.DefaultSSHPort + username = "" + host = "" + command = "" + localForwards = nil + remoteForwards = nil + strictHostKeyChecking = true + knownHostsFile = "" + identityFile = "" + sshNoBrowser = false +} + +// parseCustomSSHFlags extracts -L, -R flags and returns filtered args +func parseCustomSSHFlags(args []string) ([]string, []string, []string) { + var localForwardFlags []string + var remoteForwardFlags []string + var filteredArgs []string + + for i := 0; i < len(args); i++ { + arg := args[i] + switch { + case strings.HasPrefix(arg, "-L"): + localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags) + case strings.HasPrefix(arg, "-R"): + remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags) + default: + filteredArgs = append(filteredArgs, arg) + } + } + + return filteredArgs, localForwardFlags, remoteForwardFlags +} + +func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) { + if arg == "-L" || arg == "-R" { + if i+1 < len(args) { + flags = append(flags, args[i+1]) + i++ + } + } else if len(arg) > 2 { + flags = append(flags, arg[2:]) + } + return flags, i +} + +// extractGlobalFlags parses global flags that were passed before 'ssh' command +func extractGlobalFlags(args []string) { + sshPos := findSSHCommandPosition(args) + if sshPos == -1 { + return + } + + globalArgs := args[:sshPos] + parseGlobalArgs(globalArgs) +} + +// findSSHCommandPosition locates the 'ssh' command in the argument list +func findSSHCommandPosition(args []string) int { + for i, arg := range args { + if arg == "ssh" { + return i + } + } + return -1 +} + +const ( + configFlag = "config" + logLevelFlag = "log-level" + logFileFlag = "log-file" +) + +// parseGlobalArgs processes the global arguments and sets the corresponding variables +func parseGlobalArgs(globalArgs []string) { + flagHandlers := map[string]func(string){ + configFlag: func(value string) { configPath = value }, + logLevelFlag: func(value string) { logLevel = value }, + logFileFlag: func(value string) { + if !slices.Contains(logFiles, value) { + logFiles = append(logFiles, value) + } + }, + } + + shortFlags := map[string]string{ + "c": configFlag, + "l": logLevelFlag, + } + + for i := 0; i < len(globalArgs); i++ { + arg := globalArgs[i] + + if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled { + i = nextIndex + } + } +} + +// parseFlag handles generic flag parsing for both long and short forms +func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) { + if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found { + flagHandlers[parsedValue.flagName](parsedValue.value) + return true, currentIndex + } + + if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found { + flagHandlers[parsedValue.flagName](parsedValue.value) + return true, currentIndex + 1 + } + + return false, currentIndex +} + +type parsedFlag struct { + flagName string + value string +} + +// parseEqualsFormat handles --flag=value and -f=value formats +func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) { + if !strings.Contains(arg, "=") { + return parsedFlag{}, false + } + + parts := strings.SplitN(arg, "=", 2) + if len(parts) != 2 { + return parsedFlag{}, false + } + + if strings.HasPrefix(parts[0], "--") { + flagName := strings.TrimPrefix(parts[0], "--") + if _, exists := flagHandlers[flagName]; exists { + return parsedFlag{flagName: flagName, value: parts[1]}, true + } + } + + if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 { + shortFlag := strings.TrimPrefix(parts[0], "-") + if longFlag, exists := shortFlags[shortFlag]; exists { + if _, exists := flagHandlers[longFlag]; exists { + return parsedFlag{flagName: longFlag, value: parts[1]}, true + } + } + } + + return parsedFlag{}, false +} + +// parseSpacedFormat handles --flag value and -f value formats +func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) { + if currentIndex+1 >= len(args) { + return parsedFlag{}, false + } + + if strings.HasPrefix(arg, "--") { + flagName := strings.TrimPrefix(arg, "--") + if _, exists := flagHandlers[flagName]; exists { + return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true + } + } + + if strings.HasPrefix(arg, "-") && len(arg) == 2 { + shortFlag := strings.TrimPrefix(arg, "-") + if longFlag, exists := shortFlags[shortFlag]; exists { + if _, exists := flagHandlers[longFlag]; exists { + return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true + } + } + } + + return parsedFlag{}, false +} + +// createSSHFlagSet creates and configures the flag set for SSH command parsing +// sshFlags contains all SSH-related flags and parameters +type sshFlags struct { + Port int + Username string + Login string + RequestPTY bool + StrictHostKeyChecking bool + KnownHostsFile string + IdentityFile string + SkipCachedToken bool + NoBrowser bool + ConfigPath string + LogLevel string + LocalForwards []string + RemoteForwards []string + Host string + Command string +} + +func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { + defaultConfigPath := getEnvOrDefault("CONFIG", configPath) + defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false) + + fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) + fs.SetOutput(nil) + + flags := &sshFlags{} + + fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port") + fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port") + fs.StringVar(&flags.Username, "u", "", sshUsernameDesc) + fs.StringVar(&flags.Username, "user", "", sshUsernameDesc) + fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)") + fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation") + fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation") + + fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking") + fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file") + fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file") + fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") + fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") + fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc) + + fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") + fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") + fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level") + fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level") + + return fs, flags +} + +func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { + if len(args) < 1 { + return errors.New(hostArgumentRequired) + } + + resetSSHGlobals() + + if len(os.Args) > 2 { + extractGlobalFlags(os.Args[1:]) + } + + filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args) + + fs, flags := createSSHFlagSet() + + if err := fs.Parse(filteredArgs); err != nil { + if errors.Is(err, flag.ErrHelp) { + return nil + } + return err + } + + remaining := fs.Args() + if len(remaining) < 1 { + return errors.New(hostArgumentRequired) + } + + port = flags.Port + if flags.Username != "" { + username = flags.Username + } else if flags.Login != "" { + username = flags.Login + } + + requestPTY = flags.RequestPTY + strictHostKeyChecking = flags.StrictHostKeyChecking + knownHostsFile = flags.KnownHostsFile + identityFile = flags.IdentityFile + skipCachedToken = flags.SkipCachedToken + sshNoBrowser = flags.NoBrowser + + if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { + configPath = flags.ConfigPath + } + if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) { + logLevel = flags.LogLevel + } + + localForwards = localForwardFlags + remoteForwards = remoteForwardFlags + + return parseHostnameAndCommand(remaining) +} + +func parseHostnameAndCommand(args []string) error { + if len(args) < 1 { + return errors.New(hostArgumentRequired) + } + + arg := args[0] + if strings.Contains(arg, "@") { + parts := strings.SplitN(arg, "@", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return errors.New("invalid user@host format") + } + if username == "" { + username = parts[0] + } + host = parts[1] + } else { + host = arg + } + + if username == "" { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + username = sudoUser + } else if currentUser, err := user.Current(); err == nil { + username = currentUser.Username + } else { + username = "root" + } + } + + // Everything after hostname becomes the command + if len(args) > 1 { + command = strings.Join(args[1:], " ") + } + + return nil +} + +func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { + target := fmt.Sprintf("%s:%d", addr, port) + c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ + KnownHostsFile: knownHostsFile, + IdentityFile: identityFile, + DaemonAddr: daemonAddr, + SkipCachedToken: skipCachedToken, + InsecureSkipVerify: !strictHostKeyChecking, + NoBrowser: sshNoBrowser, + }) + + if err != nil { + cmd.Printf("Failed to connect to %s@%s\n", username, target) + cmd.Printf("\nTroubleshooting steps:\n") + cmd.Printf(" 1. Check peer connectivity: netbird status -d\n") + cmd.Printf(" 2. Verify SSH server is enabled on the peer\n") + cmd.Printf(" 3. Ensure correct hostname/IP is used\n") + return fmt.Errorf("dial %s: %w", target, err) + } + + sshCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-sshCtx.Done() + if err := c.Close(); err != nil { + cmd.Printf("Error closing SSH connection: %v\n", err) + } + }() + + if err := startPortForwarding(sshCtx, c, cmd); err != nil { + return fmt.Errorf("start port forwarding: %w", err) + } + + if command != "" { + return executeSSHCommand(sshCtx, c, command) + } + return openSSHTerminal(sshCtx, c) +} + +// executeSSHCommand executes a command over SSH. +func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error { + var err error + if requestPTY { + err = c.ExecuteCommandWithPTY(ctx, command) + } else { + err = c.ExecuteCommandWithIO(ctx, command) + } + + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + + var exitErr *ssh.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitStatus()) + } + + var exitMissingErr *ssh.ExitMissingError + if errors.As(err, &exitMissingErr) { + log.Debugf("Remote command exited without exit status: %v", err) + return nil + } + + return fmt.Errorf("execute command: %w", err) + } + return nil +} + +// openSSHTerminal opens an interactive SSH terminal. +func openSSHTerminal(ctx context.Context, c *sshclient.Client) error { + if err := c.OpenTerminal(ctx); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + + var exitMissingErr *ssh.ExitMissingError + if errors.As(err, &exitMissingErr) { + log.Debugf("Remote terminal exited without exit status: %v", err) + return nil + } + + return fmt.Errorf("open terminal: %w", err) + } + return nil +} + +// startPortForwarding starts local and remote port forwarding based on command line flags +func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error { + for _, forward := range localForwards { + if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil { + return fmt.Errorf("local port forward %s: %w", forward, err) + } + } + + for _, forward := range remoteForwards { + if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil { + return fmt.Errorf("remote port forward %s: %w", forward, err) + } + } + + return nil +} + +// parseAndStartLocalForward parses and starts a local port forward (-L) +func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error { + localAddr, remoteAddr, err := parsePortForwardSpec(forward) + if err != nil { + return err + } + + cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr) + + go func() { + if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { + cmd.Printf("Local port forward error: %v\n", err) + } + }() + + return nil +} + +// parseAndStartRemoteForward parses and starts a remote port forward (-R) +func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error { + remoteAddr, localAddr, err := parsePortForwardSpec(forward) + if err != nil { + return err + } + + cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr) + + go func() { + if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { + cmd.Printf("Remote port forward error: %v\n", err) + } + }() + + return nil +} + +// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". +// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". +func parsePortForwardSpec(spec string) (string, string, error) { + // Support formats: + // port:host:hostport -> localhost:port -> host:hostport + // host:port:host:hostport -> host:port -> host:hostport + // [host]:port:host:hostport -> [host]:port -> host:hostport + // port:unix_socket_path -> localhost:port -> unix_socket_path + // host:port:unix_socket_path -> host:port -> unix_socket_path + + if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") { + return parseIPv6ForwardSpec(spec) + } + + parts := strings.Split(spec, ":") + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec) + } + + switch len(parts) { + case 2: + return parseTwoPartForwardSpec(parts, spec) + case 3: + return parseThreePartForwardSpec(parts) + case 4: + return parseFourPartForwardSpec(parts) + default: + return "", "", fmt.Errorf("invalid port forward specification: %s", spec) + } +} + +// parseTwoPartForwardSpec handles "port:unix_socket" format. +func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) { + if isUnixSocket(parts[1]) { + localAddr := "localhost:" + parts[0] + remoteAddr := parts[1] + return localAddr, remoteAddr, nil + } + return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec) +} + +// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats. +func parseThreePartForwardSpec(parts []string) (string, string, error) { + if isUnixSocket(parts[2]) { + localHost := normalizeLocalHost(parts[0]) + localAddr := localHost + ":" + parts[1] + remoteAddr := parts[2] + return localAddr, remoteAddr, nil + } + localAddr := "localhost:" + parts[0] + remoteAddr := parts[1] + ":" + parts[2] + return localAddr, remoteAddr, nil +} + +// parseFourPartForwardSpec handles "host:port:host:hostport" format. +func parseFourPartForwardSpec(parts []string) (string, string, error) { + localHost := normalizeLocalHost(parts[0]) + localAddr := localHost + ":" + parts[1] + remoteAddr := parts[2] + ":" + parts[3] + return localAddr, remoteAddr, nil +} + +// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format. +func parseIPv6ForwardSpec(spec string) (string, string, error) { + idx := strings.Index(spec, "]:") + if idx == -1 { + return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec) + } + + ipv6Host := spec[:idx+1] + remaining := spec[idx+2:] + + parts := strings.Split(remaining, ":") + if len(parts) != 3 { + return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec) + } + + localAddr := ipv6Host + ":" + parts[0] + remoteAddr := parts[1] + ":" + parts[2] + return localAddr, remoteAddr, nil +} + +// isUnixSocket checks if a path is a Unix socket path. +func isUnixSocket(path string) bool { + return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") +} + +// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +func normalizeLocalHost(host string) string { + if host == "*" { + return "0.0.0.0" + } + return host +} + +var sshProxyCmd = &cobra.Command{ + Use: "proxy ", + Short: "Internal SSH proxy for native SSH client integration", + Long: "Internal command used by SSH ProxyCommand to handle JWT authentication", + Hidden: true, + Args: cobra.ExactArgs(2), + RunE: sshProxyFn, +} + +func sshProxyFn(cmd *cobra.Command, args []string) error { + logOutput := "console" + if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { + logOutput = firstLogFile + } + + proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(proxyLogLevel, logOutput); err != nil { + return fmt.Errorf("init log: %w", err) + } + + host := args[0] + portStr := args[1] + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + + // Check env var for browser setting since this command is invoked via SSH ProxyCommand + // where command-line flags cannot be passed. Default is to open browser. + noBrowser := getBoolEnvOrDefault("NO_BROWSER", false) + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener) + if err != nil { + return fmt.Errorf("create SSH proxy: %w", err) + } + defer func() { + if err := proxy.Close(); err != nil { + log.Debugf("close SSH proxy: %v", err) + } + }() + + if err := proxy.Connect(cmd.Context()); err != nil { + return fmt.Errorf("SSH proxy: %w", err) + } + + return nil +} + +var sshDetectCmd = &cobra.Command{ + Use: "detect ", + Short: "Detect if a host is running NetBird SSH", + Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH", + Hidden: true, + Args: cobra.ExactArgs(2), + RunE: sshDetectFn, +} + +func sshDetectFn(cmd *cobra.Command, args []string) error { + detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(detectLogLevel, "console"); err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + host := args[0] + portStr := args[1] + + port, err := strconv.Atoi(portStr) + if err != nil { + log.Debugf("invalid port %q: %v", portStr, err) + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout) + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) + if err != nil { + log.Debugf("SSH server detection failed: %v", err) + cancel() + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + cancel() + os.Exit(serverType.ExitCode()) + return nil } diff --git a/client/cmd/ssh_exec_unix.go b/client/cmd/ssh_exec_unix.go new file mode 100644 index 000000000..2412f072c --- /dev/null +++ b/client/cmd/ssh_exec_unix.go @@ -0,0 +1,74 @@ +//go:build unix + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sshExecUID uint32 + sshExecGID uint32 + sshExecGroups []uint + sshExecWorkingDir string + sshExecShell string + sshExecCommand string + sshExecPTY bool +) + +// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping +var sshExecCmd = &cobra.Command{ + Use: "exec", + Short: "Internal SSH execution with privilege dropping (hidden)", + Hidden: true, + RunE: runSSHExec, +} + +func init() { + sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID") + sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID") + sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)") + sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory") + sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute") + sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)") + sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute") + + if err := sshExecCmd.MarkFlagRequired("uid"); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err) + os.Exit(1) + } + if err := sshExecCmd.MarkFlagRequired("gid"); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err) + os.Exit(1) + } + + sshCmd.AddCommand(sshExecCmd) +} + +// runSSHExec handles the SSH exec subcommand execution. +func runSSHExec(cmd *cobra.Command, _ []string) error { + privilegeDropper := sshserver.NewPrivilegeDropper() + + var groups []uint32 + for _, groupInt := range sshExecGroups { + groups = append(groups, uint32(groupInt)) + } + + config := sshserver.ExecutorConfig{ + UID: sshExecUID, + GID: sshExecGID, + Groups: groups, + WorkingDir: sshExecWorkingDir, + Shell: sshExecShell, + Command: sshExecCommand, + PTY: sshExecPTY, + } + + privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config) + return nil +} diff --git a/client/cmd/ssh_sftp_unix.go b/client/cmd/ssh_sftp_unix.go new file mode 100644 index 000000000..c06aab017 --- /dev/null +++ b/client/cmd/ssh_sftp_unix.go @@ -0,0 +1,94 @@ +//go:build unix + +package cmd + +import ( + "errors" + "io" + "os" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sftpUID uint32 + sftpGID uint32 + sftpGroupsInt []uint + sftpWorkingDir string +) + +var sshSftpCmd = &cobra.Command{ + Use: "sftp", + Short: "SFTP server with privilege dropping (internal use)", + Hidden: true, + RunE: sftpMain, +} + +func init() { + sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID") + sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID") + sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)") + sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory") +} + +func sftpMain(cmd *cobra.Command, _ []string) error { + privilegeDropper := sshserver.NewPrivilegeDropper() + + var groups []uint32 + for _, groupInt := range sftpGroupsInt { + groups = append(groups, uint32(groupInt)) + } + + config := sshserver.ExecutorConfig{ + UID: sftpUID, + GID: sftpGID, + Groups: groups, + WorkingDir: sftpWorkingDir, + Shell: "", + Command: "", + } + + log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups) + + if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil { + cmd.PrintErrf("privilege drop failed: %v\n", err) + os.Exit(sshserver.ExitCodePrivilegeDropFail) + } + + if config.WorkingDir != "" { + if err := os.Chdir(config.WorkingDir); err != nil { + cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err) + } + } + + sftpServer, err := sftp.NewServer(struct { + io.Reader + io.WriteCloser + }{ + Reader: os.Stdin, + WriteCloser: os.Stdout, + }) + if err != nil { + cmd.PrintErrf("SFTP server creation failed: %v\n", err) + os.Exit(sshserver.ExitCodeShellExecFail) + } + + log.Tracef("starting SFTP server with dropped privileges") + if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { + cmd.PrintErrf("SFTP server error: %v\n", err) + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } + os.Exit(sshserver.ExitCodeShellExecFail) + } + + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } + os.Exit(sshserver.ExitCodeSuccess) + return nil +} diff --git a/client/cmd/ssh_sftp_windows.go b/client/cmd/ssh_sftp_windows.go new file mode 100644 index 000000000..ffd2d1148 --- /dev/null +++ b/client/cmd/ssh_sftp_windows.go @@ -0,0 +1,94 @@ +//go:build windows + +package cmd + +import ( + "errors" + "fmt" + "io" + "os" + "os/user" + "strings" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sftpWorkingDir string + windowsUsername string + windowsDomain string +) + +var sshSftpCmd = &cobra.Command{ + Use: "sftp", + Short: "SFTP server with user switching for Windows (internal use)", + Hidden: true, + RunE: sftpMain, +} + +func init() { + sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory") + sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching") + sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching") +} + +func sftpMain(cmd *cobra.Command, _ []string) error { + return sftpMainDirect(cmd) +} + +func sftpMainDirect(cmd *cobra.Command) error { + currentUser, err := user.Current() + if err != nil { + cmd.PrintErrf("failed to get current user: %v\n", err) + os.Exit(sshserver.ExitCodeValidationFail) + } + + if windowsUsername != "" { + expectedUsername := windowsUsername + if windowsDomain != "" { + expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername) + } + if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) { + cmd.PrintErrf("user switching failed\n") + os.Exit(sshserver.ExitCodeValidationFail) + } + } + + log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name) + + if sftpWorkingDir != "" { + if err := os.Chdir(sftpWorkingDir); err != nil { + cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err) + } + } + + sftpServer, err := sftp.NewServer(struct { + io.Reader + io.WriteCloser + }{ + Reader: os.Stdin, + WriteCloser: os.Stdout, + }) + if err != nil { + cmd.PrintErrf("SFTP server creation failed: %v\n", err) + os.Exit(sshserver.ExitCodeShellExecFail) + } + + log.Debugf("starting SFTP server") + exitCode := sshserver.ExitCodeSuccess + if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { + cmd.PrintErrf("SFTP server error: %v\n", err) + exitCode = sshserver.ExitCodeShellExecFail + } + + if err := sftpServer.Close(); err != nil { + log.Debugf("SFTP server close error: %v", err) + } + + os.Exit(exitCode) + return nil +} diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go new file mode 100644 index 000000000..43291fa87 --- /dev/null +++ b/client/cmd/ssh_test.go @@ -0,0 +1,717 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSHCommand_FlagParsing(t *testing.T) { + tests := []struct { + name string + args []string + expectedHost string + expectedUser string + expectedPort int + expectedCmd string + expectError bool + }{ + { + name: "basic host", + args: []string{"hostname"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "", + }, + { + name: "user@host format", + args: []string{"user@hostname"}, + expectedHost: "hostname", + expectedUser: "user", + expectedPort: 22, + expectedCmd: "", + }, + { + name: "host with command", + args: []string{"hostname", "echo", "hello"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "echo hello", + }, + { + name: "command with flags should be preserved", + args: []string{"hostname", "ls", "-la", "/tmp"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "ls -la /tmp", + }, + { + name: "double dash separator", + args: []string{"hostname", "--", "ls", "-la"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "-- ls -la", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + // Mock command for testing + cmd := sshCmd + cmd.SetArgs(tt.args) + + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + if tt.expectedUser != "" { + assert.Equal(t, tt.expectedUser, username, "username mismatch") + } + assert.Equal(t, tt.expectedPort, port, "port mismatch") + assert.Equal(t, tt.expectedCmd, command, "command mismatch") + }) + } +} + +func TestSSHCommand_FlagConflictPrevention(t *testing.T) { + // Test that SSH flags don't conflict with command flags + tests := []struct { + name string + args []string + expectedCmd string + description string + }{ + { + name: "ls with -la flags", + args: []string{"hostname", "ls", "-la"}, + expectedCmd: "ls -la", + description: "ls flags should be passed to remote command", + }, + { + name: "grep with -r flag", + args: []string{"hostname", "grep", "-r", "pattern", "/path"}, + expectedCmd: "grep -r pattern /path", + description: "grep flags should be passed to remote command", + }, + { + name: "ps with aux flags", + args: []string{"hostname", "ps", "aux"}, + expectedCmd: "ps aux", + description: "ps flags should be passed to remote command", + }, + { + name: "command with double dash", + args: []string{"hostname", "--", "ls", "-la"}, + expectedCmd: "-- ls -la", + description: "double dash should be preserved in command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedCmd, command, tt.description) + }) + } +} + +func TestSSHCommand_NonInteractiveExecution(t *testing.T) { + // Test that commands with arguments should execute the command and exit, + // not drop to an interactive shell + tests := []struct { + name string + args []string + expectedCmd string + shouldExit bool + description string + }{ + { + name: "ls command should execute and exit", + args: []string{"hostname", "ls"}, + expectedCmd: "ls", + shouldExit: true, + description: "ls command should execute and exit, not drop to shell", + }, + { + name: "ls with flags should execute and exit", + args: []string{"hostname", "ls", "-la"}, + expectedCmd: "ls -la", + shouldExit: true, + description: "ls with flags should execute and exit, not drop to shell", + }, + { + name: "pwd command should execute and exit", + args: []string{"hostname", "pwd"}, + expectedCmd: "pwd", + shouldExit: true, + description: "pwd command should execute and exit, not drop to shell", + }, + { + name: "echo command should execute and exit", + args: []string{"hostname", "echo", "hello"}, + expectedCmd: "echo hello", + shouldExit: true, + description: "echo command should execute and exit, not drop to shell", + }, + { + name: "no command should open shell", + args: []string{"hostname"}, + expectedCmd: "", + shouldExit: false, + description: "no command should open interactive shell", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedCmd, command, tt.description) + + // When command is present, it should execute the command and exit + // When command is empty, it should open interactive shell + hasCommand := command != "" + assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior") + }) + } +} + +func TestSSHCommand_FlagHandling(t *testing.T) { + // Test that flags after hostname are not parsed by netbird but passed to SSH command + tests := []struct { + name string + args []string + expectedHost string + expectedCmd string + expectError bool + description string + }{ + { + name: "ls with -la flag should not be parsed by netbird", + args: []string{"debian2", "ls", "-la"}, + expectedHost: "debian2", + expectedCmd: "ls -la", + expectError: false, + description: "ls -la should be passed as SSH command, not parsed as netbird flags", + }, + { + name: "command with netbird-like flags should be passed through", + args: []string{"hostname", "echo", "--help"}, + expectedHost: "hostname", + expectedCmd: "echo --help", + expectError: false, + description: "--help should be passed to echo, not parsed by netbird", + }, + { + name: "command with -p flag should not conflict with SSH port flag", + args: []string{"hostname", "ps", "-p", "1234"}, + expectedHost: "hostname", + expectedCmd: "ps -p 1234", + expectError: false, + description: "ps -p should be passed to ps command, not parsed as port", + }, + { + name: "tar with flags should be passed through", + args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"}, + expectedHost: "hostname", + expectedCmd: "tar -czf backup.tar.gz /home", + expectError: false, + description: "tar flags should be passed to tar command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + assert.Equal(t, tt.expectedCmd, command, tt.description) + }) + } +} + +func TestSSHCommand_RegressionFlagParsing(t *testing.T) { + // Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la" + // should not parse -la as netbird flags but pass them to the SSH command + tests := []struct { + name string + args []string + expectedHost string + expectedCmd string + expectError bool + description string + }{ + { + name: "original issue: ls -la should be preserved", + args: []string{"debian2", "ls", "-la"}, + expectedHost: "debian2", + expectedCmd: "ls -la", + expectError: false, + description: "The original failing case should now work", + }, + { + name: "ls -l should be preserved", + args: []string{"hostname", "ls", "-l"}, + expectedHost: "hostname", + expectedCmd: "ls -l", + expectError: false, + description: "Single letter flags should be preserved", + }, + { + name: "SSH port flag should work", + args: []string{"-p", "2222", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedCmd: "ls -la", + expectError: false, + description: "SSH -p flag should be parsed, command flags preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + assert.Equal(t, tt.expectedCmd, command, tt.description) + + // Check port for the test case with -p flag + if len(tt.args) > 0 && tt.args[0] == "-p" { + assert.Equal(t, 2222, port, "port should be parsed from -p flag") + } + }) + } +} + +func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) { + tests := []struct { + name string + args []string + expectedHost string + expectedLocal []string + expectedRemote []string + expectError bool + description string + }{ + { + name: "local port forwarding -L", + args: []string{"-L", "8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Single -L flag should be parsed correctly", + }, + { + name: "remote port forwarding -R", + args: []string{"-R", "8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80"}, + expectError: false, + description: "Single -R flag should be parsed correctly", + }, + { + name: "multiple local port forwards", + args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"}, + expectedRemote: []string{}, + expectError: false, + description: "Multiple -L flags should be parsed correctly", + }, + { + name: "multiple remote port forwards", + args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"}, + expectError: false, + description: "Multiple -R flags should be parsed correctly", + }, + { + name: "mixed local and remote forwards", + args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{"9090:localhost:443"}, + expectError: false, + description: "Mixed -L and -R flags should be parsed correctly", + }, + { + name: "port forwarding with bind address", + args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"127.0.0.1:8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Port forwarding with bind address should work", + }, + { + name: "port forwarding with command", + args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Port forwarding with command should work", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + localForwards = nil + remoteForwards = nil + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + // Handle nil vs empty slice comparison + if len(tt.expectedLocal) == 0 { + assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty") + } else { + assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards") + } + if len(tt.expectedRemote) == 0 { + assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty") + } else { + assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards") + } + }) + } +} + +func TestParsePortForward(t *testing.T) { + tests := []struct { + name string + spec string + expectedLocal string + expectedRemote string + expectError bool + description string + }{ + { + name: "simple port forward", + spec: "8080:localhost:80", + expectedLocal: "localhost:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "Simple port:host:port format should work", + }, + { + name: "port forward with bind address", + spec: "127.0.0.1:8080:localhost:80", + expectedLocal: "127.0.0.1:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "bind_address:port:host:port format should work", + }, + { + name: "port forward to different host", + spec: "8080:example.com:443", + expectedLocal: "localhost:8080", + expectedRemote: "example.com:443", + expectError: false, + description: "Forwarding to different host should work", + }, + { + name: "port forward with IPv6 (needs bracket support)", + spec: "::1:8080:localhost:80", + expectError: true, + description: "IPv6 without brackets fails as expected (feature to implement)", + }, + { + name: "invalid format - too few parts", + spec: "8080:localhost", + expectError: true, + description: "Invalid format with too few parts should fail", + }, + { + name: "invalid format - too many parts", + spec: "127.0.0.1:8080:localhost:80:extra", + expectError: true, + description: "Invalid format with too many parts should fail", + }, + { + name: "empty spec", + spec: "", + expectError: true, + description: "Empty spec should fail", + }, + { + name: "unix socket local forward", + spec: "8080:/tmp/socket", + expectedLocal: "localhost:8080", + expectedRemote: "/tmp/socket", + expectError: false, + description: "Unix socket forwarding should work", + }, + { + name: "unix socket with bind address", + spec: "127.0.0.1:8080:/tmp/socket", + expectedLocal: "127.0.0.1:8080", + expectedRemote: "/tmp/socket", + expectError: false, + description: "Unix socket with bind address should work", + }, + { + name: "wildcard bind all interfaces", + spec: "*:8080:localhost:80", + expectedLocal: "0.0.0.0:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "Wildcard * should bind to all interfaces (0.0.0.0)", + }, + { + name: "wildcard for port only", + spec: "8080:*:80", + expectedLocal: "localhost:8080", + expectedRemote: "*:80", + expectError: false, + description: "Wildcard in remote host should be preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec) + + if tt.expectError { + assert.Error(t, err, tt.description) + return + } + + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address") + assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address") + }) + } +} + +func TestSSHCommand_IntegrationPortForwarding(t *testing.T) { + // Integration test for port forwarding with the actual SSH command implementation + tests := []struct { + name string + args []string + expectedHost string + expectedLocal []string + expectedRemote []string + expectedCmd string + description string + }{ + { + name: "local forward with command", + args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectedCmd: "echo test", + description: "Local forwarding should work with commands", + }, + { + name: "remote forward with command", + args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80"}, + expectedCmd: "ls -la", + description: "Remote forwarding should work with commands", + }, + { + name: "multiple forwards with user and command", + args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{"9090:localhost:443"}, + expectedCmd: "ps aux", + description: "Complex case with multiple forwards, user, and command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + localForwards = nil + remoteForwards = nil + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedHost, host, "host mismatch") + // Handle nil vs empty slice comparison + if len(tt.expectedLocal) == 0 { + assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty") + } else { + assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards") + } + if len(tt.expectedRemote) == 0 { + assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty") + } else { + assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards") + } + assert.Equal(t, tt.expectedCmd, command, tt.description+" - command") + }) + } +} + +func TestSSHCommand_ParameterIsolation(t *testing.T) { + tests := []struct { + name string + args []string + expectedCmd string + }{ + { + name: "cmd flag passed as command", + args: []string{"hostname", "--cmd", "echo test"}, + expectedCmd: "--cmd echo test", + }, + { + name: "uid flag passed as command", + args: []string{"hostname", "--uid", "1000"}, + expectedCmd: "--uid 1000", + }, + { + name: "shell flag passed as command", + args: []string{"hostname", "--shell", "/bin/bash"}, + expectedCmd: "--shell /bin/bash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host = "" + username = "" + port = 22 + command = "" + + err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args) + require.NoError(t, err) + + assert.Equal(t, "hostname", host) + assert.Equal(t, tt.expectedCmd, command) + }) + } +} + +func TestSSHCommand_InvalidFlagRejection(t *testing.T) { + // Test that invalid flags are properly rejected and not misinterpreted as hostnames + tests := []struct { + name string + args []string + description string + }{ + { + name: "invalid long flag before hostname", + args: []string{"--invalid-flag", "hostname"}, + description: "Invalid flag should return parse error, not treat flag as hostname", + }, + { + name: "invalid short flag before hostname", + args: []string{"-x", "hostname"}, + description: "Invalid short flag should return parse error", + }, + { + name: "invalid flag with value before hostname", + args: []string{"--invalid-option=value", "hostname"}, + description: "Invalid flag with value should return parse error", + }, + { + name: "typo in known flag", + args: []string{"--por", "2222", "hostname"}, + description: "Typo in flag name should return parse error (not silently ignored)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args) + + // Should return an error for invalid flags + assert.Error(t, err, tt.description) + + // Should not have set host to the invalid flag + assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname") + }) + } +} diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..06460a6a7 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { case yamlFlag: statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) default: - statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false) + statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false) } if err != nil { @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index bd3209605..b9ff35945 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -13,6 +13,12 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" @@ -20,8 +26,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -84,7 +88,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, nil @@ -110,13 +113,21 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) + + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index d047c041e..9efc2e60d 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) - err = foregroundLogin(ctx, cmd, config, providedSetupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r) + connectClient := internal.NewConnectClient(ctx, config, r, false) SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil) @@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ loginRequest.ProfileName = &activeProf.Name loginRequest.Username = &username + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + var loginErr error var loginResp *proto.LoginResponse @@ -348,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro if cmd.Flag(serverSSHAllowedFlag).Changed { req.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + req.EnableSSHRoot = &enableSSHRoot + } + if cmd.Flag(enableSSHSFTPFlag).Changed { + req.EnableSSHSFTP = &enableSSHSFTP + } + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + if cmd.Flag(disableSSHAuthFlag).Changed { + req.DisableSSHAuth = &disableSSHAuth + } + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + req.SshJWTCacheTTL = &sshJWTCacheTTL32 + } if cmd.Flag(interfaceNameFlag).Changed { if err := parseInterfaceName(interfaceName); err != nil { log.Errorf("parse interface name: %v", err) @@ -432,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + ic.EnableSSHRoot = &enableSSHRoot + } + + if cmd.Flag(enableSSHSFTPFlag).Changed { + ic.EnableSSHSFTP = &enableSSHSFTP + } + + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + + if cmd.Flag(disableSSHAuthFlag).Changed { + ic.DisableSSHAuth = &disableSSHAuth + } + + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + ic.SSHJWTCacheTTL = &sshJWTCacheTTL + } + if cmd.Flag(interfaceNameFlag).Changed { if err := parseInterfaceName(interfaceName); err != nil { return nil, err @@ -532,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + loginRequest.EnableSSHRoot = &enableSSHRoot + } + + if cmd.Flag(enableSSHSFTPFlag).Changed { + loginRequest.EnableSSHSFTP = &enableSSHSFTP + } + + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + + if cmd.Flag(disableSSHAuthFlag).Changed { + loginRequest.DisableSSHAuth = &disableSSHAuth + } + + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32 + } + if cmd.Flag(disableAutoConnectFlag).Changed { loginRequest.DisableAutoConnect = &autoConnectDisabled } diff --git a/client/cmd/update.go b/client/cmd/update.go new file mode 100644 index 000000000..dc49b02c3 --- /dev/null +++ b/client/cmd/update.go @@ -0,0 +1,13 @@ +//go:build !windows && !darwin + +package cmd + +import ( + "github.com/spf13/cobra" +) + +var updateCmd *cobra.Command + +func isUpdateBinary() bool { + return false +} diff --git a/client/cmd/update_supported.go b/client/cmd/update_supported.go new file mode 100644 index 000000000..977875093 --- /dev/null +++ b/client/cmd/update_supported.go @@ -0,0 +1,75 @@ +//go:build windows || darwin + +package cmd + +import ( + "context" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/util" +) + +var ( + updateCmd = &cobra.Command{ + Use: "update", + Short: "Update the NetBird client application", + RunE: updateFunc, + } + + tempDirFlag string + installerFile string + serviceDirFlag string + dryRunFlag bool +) + +func init() { + updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir") + updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file") + updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory") + updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes") +} + +// isUpdateBinary checks if the current executable is named "update" or "update.exe" +func isUpdateBinary() bool { + // Remove extension for cross-platform compatibility + execPath, err := os.Executable() + if err != nil { + return false + } + baseName := filepath.Base(execPath) + name := strings.TrimSuffix(baseName, filepath.Ext(baseName)) + + return name == installer.UpdaterBinaryNameWithoutExtension() +} + +func updateFunc(cmd *cobra.Command, args []string) error { + if err := setupLogToFile(tempDirFlag); err != nil { + return err + } + + log.Infof("updater started: %s", serviceDirFlag) + updater := installer.NewWithDir(tempDirFlag) + if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil { + log.Errorf("failed to update application: %v", err) + return err + } + return nil +} + +func setupLogToFile(dir string) error { + logFile := filepath.Join(dir, installer.LogFile) + + if _, err := os.Stat(logFile); err == nil { + if err := os.Remove(logFile); err != nil { + log.Errorf("failed to remove existing log file: %v\n", err) + } + } + + return util.InitLog(logLevel, util.LogConsole, logFile) +} diff --git a/client/embed/embed.go b/client/embed/embed.go index e918235ed..353c5438f 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -18,12 +18,16 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + sshcommon "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" ) -var ErrClientAlreadyStarted = errors.New("client already started") -var ErrClientNotStarted = errors.New("client not started") -var ErrConfigNotInitialized = errors.New("config not initialized") +var ( + ErrClientAlreadyStarted = errors.New("client already started") + ErrClientNotStarted = errors.New("client not started") + ErrEngineNotStarted = errors.New("engine not started") + ErrConfigNotInitialized = errors.New("config not initialized") +) // Client manages a netbird embedded client instance. type Client struct { @@ -169,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error { } recorder := peer.NewRecorder(c.config.ManagementURL.String()) - client := internal.NewConnectClient(ctx, c.config, recorder) + client := internal.NewConnectClient(ctx, c.config, recorder, false) // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available @@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) { // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { - c.mu.Lock() - connect := c.connect - if connect == nil { - c.mu.Unlock() - return nil, ErrClientNotStarted - } - c.mu.Unlock() - - engine := connect.Engine() - if engine == nil { - return nil, errors.New("engine not started") + engine, err := c.getEngine() + if err != nil { + return nil, err } nsnet, err := engine.GetNet() @@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } +// DialContext dials a network address in the netbird network with context +func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return c.Dial(ctx, network, address) +} + // ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { @@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client { } } -func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { +// VerifySSHHostKey verifies an SSH host key against stored peer keys. +// Returns nil if the key matches, ErrPeerNotFound if peer is not in network, +// ErrNoStoredKey if peer has no stored key, or an error for verification failures. +func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { + engine, err := c.getEngine() + if err != nil { + return err + } + + storedKey, found := engine.GetPeerSSHKey(peerAddress) + if !found { + return sshcommon.ErrPeerNotFound + } + + return sshcommon.VerifyHostKey(storedKey, key, peerAddress) +} + +// getEngine safely retrieves the engine from the client with proper locking. +// Returns ErrClientNotStarted if the client is not started. +// Returns ErrEngineNotStarted if the engine is not available. +func (c *Client) getEngine() (*internal.Engine, error) { c.mu.Lock() connect := c.connect - if connect == nil { - c.mu.Unlock() - return nil, netip.Addr{}, errors.New("client not started") - } c.mu.Unlock() + if connect == nil { + return nil, ErrClientNotStarted + } + engine := connect.Engine() if engine == nil { - return nil, netip.Addr{}, errors.New("engine not started") + return nil, ErrEngineNotStarted + } + + return engine, nil +} + +func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { + engine, err := c.getEngine() + if err != nil { + return nil, netip.Addr{}, err } addr, err := engine.Address() diff --git a/client/firewall/create.go b/client/firewall/create.go index 7b265e1d1..24f12bc6d 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -15,13 +15,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index aa2f0d4d1..12dcaee8a 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) if !iface.IsUserspaceBind() { return fm, err @@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) + return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { - fm, err := createFW(iface) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { + fm, err := createFW(iface, mtu) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) } @@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, return fm, nil } -func createFW(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) { switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") - return nbiptables.Create(iface) + return nbiptables.Create(iface, mtu) case NFTABLES: log.Info("creating an nftables firewall manager") - return nbnftables.Create(iface) + return nbnftables.Create(iface, mtu) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") return nil, errors.New("no firewall manager found") } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } else { - fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) } if errUsp != nil { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..5ccaf17ba 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -1,13 +1,14 @@ package iptables import ( + "errors" "fmt" "net" "slices" "github.com/coreos/go-iptables/iptables" "github.com/google/uuid" - "github.com/nadoo/ipset" + ipset "github.com/lrh3321/ipset-go" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -40,19 +41,13 @@ type aclManager struct { } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { - m := &aclManager{ + return &aclManager{ iptablesClient: iptablesClient, wgIface: wgIface, entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), - } - - if err := ipset.Init(); err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - - return m, nil + }, nil } func (m *aclManager) init(stateManager *statemanager.Manager) error { @@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering( specs = append(specs, "-j", actionToStr(action)) if ipsetName != "" { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + if err := m.addToIPSet(ipsetName, ip); err != nil { + return nil, fmt.Errorf("add IP to ipset: %w", err) } // if ruleset already exists it means we already have the firewall rule // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. @@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering( }}, nil } - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %s before use it: %s", ipsetName, err) + if err := m.flushIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("flush ipset %s before use: %v", ipsetName, err) + } else { + log.Errorf("flush ipset %s before use: %v", ipsetName, err) + } } - if err := ipset.Create(ipsetName); err != nil { - return nil, fmt.Errorf("failed to create ipset: %w", err) + if err := m.createIPSet(ipsetName); err != nil { + return nil, fmt.Errorf("create ipset: %w", err) } - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + if err := m.addToIPSet(ipsetName, ip); err != nil { + return nil, fmt.Errorf("add IP to ipset: %w", err) } ipList := newIpList(ip.String()) @@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("invalid rule type") } + shouldDestroyIpset := false if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { - if err := ipset.Del(r.ipsetName, r.ip); err != nil { - return fmt.Errorf("failed to delete ip from ipset: %w", err) + ip := net.ParseIP(r.ip) + if ip == nil { + return fmt.Errorf("parse IP %s", r.ip) + } + if err := m.delFromIPSet(r.ipsetName, ip); err != nil { + return fmt.Errorf("delete ip from ipset: %w", err) } delete(ipsetList.ips, r.ip) } @@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { // we delete last IP from the set, that means we need to delete // set itself and associated firewall rule too m.ipsetStore.deleteIpset(r.ipsetName) - - if err := ipset.Destroy(r.ipsetName); err != nil { - log.Errorf("delete empty ipset: %v", err) - } + shouldDestroyIpset = true } if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { @@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { } } + if shouldDestroyIpset { + if err := m.destroyIPSet(r.ipsetName); err != nil { + if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("destroy empty ipset: %v", err) + } else { + log.Errorf("destroy empty ipset: %v", err) + } + } + } + m.updateState() return nil @@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error { } for _, ipsetName := range m.ipsetStore.ipsetNames() { - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + if err := m.flushIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("flush ipset %q during reset: %v", ipsetName, err) + } else { + log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + } } - if err := ipset.Destroy(ipsetName); err != nil { - log.Errorf("delete ipset %q during reset: %v", ipsetName, err) + if err := m.destroyIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("destroy ipset %q during reset: %v", ipsetName, err) + } else { + log.Errorf("destroy ipset %q during reset: %v", ipsetName, err) + } } m.ipsetStore.deleteIpset(ipsetName) } @@ -368,8 +387,8 @@ func (m *aclManager) updateState() { // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { matchByIP := true - // don't use IP matching if IP is ip 0.0.0.0 - if ip.String() == "0.0.0.0" { + // don't use IP matching if IP is 0.0.0.0 + if ip.IsUnspecified() { matchByIP = false } @@ -400,7 +419,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" @@ -417,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return ipsetName + actionSuffix } } + +func (m *aclManager) createIPSet(name string) error { + opts := ipset.CreateOptions{ + Replace: true, + } + + if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { + return fmt.Errorf("create ipset %s: %w", name, err) + } + + log.Debugf("created ipset %s with type hash:net", name) + return nil +} + +func (m *aclManager) addToIPSet(name string, ip net.IP) error { + cidr := uint8(32) + if ip.To4() == nil { + cidr = 128 + } + + entry := &ipset.Entry{ + IP: ip, + CIDR: cidr, + Replace: true, + } + + if err := ipset.Add(name, entry); err != nil { + return fmt.Errorf("add IP to ipset %s: %w", name, err) + } + + return nil +} + +func (m *aclManager) delFromIPSet(name string, ip net.IP) error { + cidr := uint8(32) + if ip.To4() == nil { + cidr = 128 + } + + entry := &ipset.Entry{ + IP: ip, + CIDR: cidr, + } + + if err := ipset.Del(name, entry); err != nil { + return fmt.Errorf("delete IP from ipset %s: %w", name, err) + } + + return nil +} + +func (m *aclManager) flushIPSet(name string) error { + return ipset.Flush(name) +} + +func (m *aclManager) destroyIPSet(name string) error { + return ipset.Destroy(name) +} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 81f7a9125..2563a9052 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -36,7 +36,7 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("init iptables: %w", err) @@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) @@ -260,6 +261,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index a5cc62feb..6b5401e2b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..1fe4c149f 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -10,7 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/hashicorp/go-multierror" - "github.com/nadoo/ipset" + ipset "github.com/lrh3321/ipset-go" log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" @@ -30,17 +30,20 @@ const ( chainPOSTROUTING = "POSTROUTING" chainPREROUTING = "PREROUTING" + chainFORWARD = "FORWARD" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" matchSet = "--match-set" @@ -48,6 +51,9 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" fwdSuffix = "_fwd" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) type ruleInfo struct { @@ -77,16 +83,18 @@ type router struct { ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + mtu uint16 stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState } -func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, + mtu: mtu, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -99,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, }, ) - if err := ipset.Init(); err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - return r, nil } @@ -224,12 +228,12 @@ func (r *router) findSets(rule []string) []string { } func (r *router) createIpSet(setName string, sources []netip.Prefix) error { - if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + if err := r.createIPSet(setName); err != nil { return fmt.Errorf("create set %s: %w", setName, err) } for _, prefix := range sources { - if err := ipset.AddPrefix(setName, prefix); err != nil { + if err := r.addPrefixToIPSet(setName, prefix); err != nil { return fmt.Errorf("add element to set %s: %w", setName, err) } } @@ -238,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error { } func (r *router) deleteIpSet(setName string) error { - if err := ipset.Destroy(setName); err != nil { + if err := r.destroyIPSet(setName); err != nil { return fmt.Errorf("destroy set %s: %w", setName, err) } @@ -392,6 +396,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { @@ -416,6 +421,7 @@ func (r *router) createContainers() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) @@ -438,6 +444,10 @@ func (r *router) createContainers() error { return fmt.Errorf("add jump rules: %w", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + return nil } @@ -518,6 +528,35 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + // Add jump rule from FORWARD chain in mangle table to our custom chain + jumpRule := []string{ + "-j", chainRTMSSCLAMP, + } + if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil { + return fmt.Errorf("add jump to MSS clamp chain: %w", err) + } + r.rules[jumpMSSClamp] = jumpRule + + ruleOut := []string{ + "-o", r.wgIface.Name(), + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mss), + } + if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil { + return fmt.Errorf("add outbound MSS clamp rule: %w", err) + } + r.rules["mss-clamp-out"] = ruleOut + + return nil +} + func (r *router) insertEstablishedRule(chain string) error { establishedRule := getConntrackEstablished() @@ -558,7 +597,7 @@ func (r *router) addJumpRules() error { } func (r *router) cleanJumpRules() error { - for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { + for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} { if rule, exists := r.rules[ruleKey]; exists { var table, chain string switch ruleKey { @@ -571,6 +610,9 @@ func (r *router) cleanJumpRules() error { case jumpNatPre: table = tableNat chain = chainPREROUTING + case jumpMSSClamp: + table = tableMangle + chain = chainFORWARD default: return fmt.Errorf("unknown jump rule: %s", ruleKey) } @@ -869,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) continue } - if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err)) + if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { @@ -880,6 +922,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + dnatRule := []string{ + "-i", r.wgIface.Name(), + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + ruleInfo := ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + r.rules[ruleID] = ruleInfo.rule + + r.updateState() + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + return fmt.Errorf("delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil @@ -899,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } + +func (r *router) createIPSet(name string) error { + opts := ipset.CreateOptions{ + Replace: true, + } + + if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { + return fmt.Errorf("create ipset %s: %w", name, err) + } + + log.Debugf("created ipset %s with type hash:net", name) + return nil +} + +func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error { + addr := prefix.Addr() + ip := addr.AsSlice() + + entry := &ipset.Entry{ + IP: ip, + CIDR: uint8(prefix.Bits()), + Replace: true, + } + + if err := ipset.Add(name, entry); err != nil { + return fmt.Errorf("add prefix to ipset %s: %w", name, err) + } + + return nil +} + +func (r *router) destroyIPSet(name string) error { + return ipset.Destroy(name) +} diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 3490c5dad..6707573be 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, manager.init(nil)) @@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { assert.NoError(t, manager.Reset(), "shouldn't return error") }() - // Now 5 rules: // 1. established rule forward in // 2. estbalished rule forward out // 3. jump rule to POST nat chain @@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { // 7. static return masquerade rule // 8. mangle prerouting mark rule // 9. mangle postrouting mark rule - require.Len(t, manager.rules, 9, "should have created rules map") + // 10. jump rule to MSS clamping chain + // 11. MSS clamping rule for outbound traffic + require.Len(t, manager.rules, 11, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) @@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) @@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { @@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router manager") require.NoError(t, r.init(nil)) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 6ef159e01..c88774c1f 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -11,6 +12,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - ipt, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + ipt, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create iptables manager: %w", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3b3164823..72e6a5c68 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // // If comment argument is empty firewall manager should set // rule ID as comment for the rule + // + // Note: Callers should call Flush() after adding rules to ensure + // they are applied to the kernel and rule handles are refreshed. AddPeerFiltering( id []byte, ip net.IP, @@ -151,14 +154,20 @@ type Manager interface { DisableRouting() error - // AddDNATRule adds a DNAT rule + // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. AddDNATRule(ForwardRule) (Rule, error) - // DeleteDNATRule deletes a DNAT rule + // DeleteDNATRule deletes the outbound DNAT rule. DeleteDNATRule(Rule) error // UpdateSet updates the set with the given prefixes UpdateSet(hash Set, prefixes []netip.Prefix) error + + // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveInboundDNAT removes inbound DNAT rule + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..a9d066e2f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -29,8 +29,6 @@ const ( chainNameForwardFilter = "netbird-acl-forward-filter" chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePostrouting = "netbird-mangle-postrouting" - - allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) const flushError = "flush: %w" @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { // createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, &expr.Verdict{ Kind: expr.VerdictAccept, }, @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering( action firewall.Action, ipset *nftables.Set, ) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) + ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ nftRule: r.nftRule, @@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering( } if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf(flushError, err) + return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err) } ruleStruct := &Rule{ - nftRule: nftRule, + nftRule: nftRule, + // best effort mangle rule mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt }, ) - return m.rConn.AddRule(&nftables.Rule{ + nfRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: m.chainPrerouting, Exprs: preroutingExprs, UserData: userData, }) + + if err := m.rConn.Flush(); err != nil { + log.Errorf("failed to flush mangle rule %s: %v", string(userData), err) + return nil + } + + return nfRule } func (m *AclManager) createDefaultChains() (err error) { @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { - rulesetID := ":" +func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" + string(proto) + ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 560f224f5..bd19f1067 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -1,11 +1,11 @@ package nftables import ( - "bytes" "context" "fmt" "net" "net/netip" + "os" "sync" "github.com/google/nftables" @@ -19,13 +19,22 @@ import ( ) const ( - // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client tableNameNetbird = "netbird" + // envTableName is the environment variable to override the table name + envTableName = "NB_NFTABLES_TABLE" tableNameFilter = "filter" chainNameInput = "INPUT" ) +func getTableName() string { + if name := os.Getenv(envTableName); name != "" { + return name + } + return tableNameNetbird +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string @@ -44,16 +53,16 @@ type Manager struct { } // Create nftables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, } - workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} var err error - m.router, err = newRouter(workTable, wgIface) + m.router, err = newRouter(workTable, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) @@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - err := m.aclManager.createDefaultAllowRules() - if err != nil { - return fmt.Errorf("failed to create default allow rules: %v", err) + if err := m.aclManager.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create default allow rules: %w", err) } - - chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - - var chain *nftables.Chain - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - chain = c - break - } - } - - if chain == nil { - log.Debugf("chain INPUT not found. Skipping add allow netbird rule") - return nil - } - - rules, err := m.rConn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) - } - - if rule := m.detectAllowNetbirdRule(rules); rule != nil { - log.Debugf("allow netbird rule already exists: %v", rule) - return nil - } - - m.applyAllowNetbirdRules(chain) - - err = m.rConn.Flush() - if err != nil { - return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush allow input netbird rules: %w", err) } return nil @@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.resetNetbirdInputRules(); err != nil { - return fmt.Errorf("reset netbird input rules: %v", err) - } - if err := m.router.Reset(); err != nil { return fmt.Errorf("reset router: %v", err) } @@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nil } -func (m *Manager) resetNetbirdInputRules() error { - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list chains: %w", err) - } - - m.deleteNetbirdInputRules(chains) - - return nil -} - -func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - - m.deleteMatchingRules(rules) - } - } -} - -func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } -} - func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } @@ -376,61 +315,40 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } -func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { - rule := &nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, - UserData: []byte(allowNetbirdInputRuleID), - } - _ = m.rConn.InsertRule(rule) -} - -func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { - ifName := ifname(m.wgIface.Name()) - for _, rule := range existedRules { - if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { - if len(rule.Exprs) < 4 { - if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { - continue - } - if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { - continue - } - return rule - } - } - } - return nil -} - func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { rule := &nftables.Rule{ Table: table, diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index c7f05dcb7..adec802c8 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) { // This test verifies rule insertion order in nftables peer ACLs // We add accept rule first, then deny rule to test ordering behavior - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr := runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err, "failed to create manager") require.NoError(t, manager.Init(nil)) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..7f95992da 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -16,6 +16,7 @@ import ( "github.com/google/nftables/xt" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -26,18 +27,27 @@ import ( ) const ( - tableNat = "nat" + tableNat = "nat" + tableMangle = "mangle" + tableRaw = "raw" + tableSecurity = "security" + chainNameNatPrerouting = "PREROUTING" chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" chainNameForward = "FORWARD" + chainNameMangleForward = "netbird-mangle-forward" userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + userDataAcceptInputRule = "inputaccept" dnatSuffix = "_dnat" snatSuffix = "_snat" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) const refreshRulesMapError = "refresh rules map: %w" @@ -63,9 +73,10 @@ type router struct { wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool + mtu uint16 } -func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { +func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ conn: &nftables.Conn{}, workTable: workTable, @@ -73,6 +84,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) rules: make(map[string]*nftables.Rule), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), + mtu: mtu, } r.ipsetCounter = refcounter.New( @@ -83,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) var err error r.filterTable, err = r.loadFilterTable() if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, fmt.Errorf("load filter table: %w", err) - } + log.Debugf("ip filter table not found: %v", err) } return r, nil @@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) func (r *router) init(workTable *nftables.Table) error { r.workTable = workTable - if err := r.removeAcceptForwardRules(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + if err := r.removeAcceptFilterRules(); err != nil { + log.Errorf("failed to clean up rules from filter table: %s", err) } if err := r.createContainers(); err != nil { @@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error { return nil } -// Reset cleans existing nftables default forward rules from the system +// Reset cleans existing nftables filter table rules from the system func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() var merr *multierror.Error - if err := r.removeAcceptForwardRules(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + if err := r.removeAcceptFilterRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } if err := r.removeNatPreroutingRules(); err != nil { @@ -167,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("unable to list tables: %v", err) + return nil, fmt.Errorf("list tables: %w", err) } for _, table := range tables { @@ -179,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { return nil, errFilterTableNotFound } +func hookName(hook *nftables.ChainHook) string { + if hook == nil { + return "unknown" + } + switch *hook { + case *nftables.ChainHookForward: + return chainNameForward + case *nftables.ChainHookInput: + return chainNameInput + default: + return fmt.Sprintf("hook(%d)", *hook) + } +} + +func familyName(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + default: + return fmt.Sprintf("family(%d)", family) + } +} + func (r *router) createContainers() error { r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, Table: r.workTable, }) - insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) - prio := *nftables.ChainPriorityNATSource - 1 r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, @@ -220,9 +253,24 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) - // Add the single NAT rule that matches on mark - if err := r.addPostroutingRules(); err != nil { - return fmt.Errorf("add single nat rule: %v", err) + r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameMangleForward, + Table: r.workTable, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + }) + + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.addPostroutingRules() + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("initialize tables: %v", err) + } + + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) } if err := r.acceptForwardRules(); err != nil { @@ -230,11 +278,7 @@ func (r *router) createContainers() error { } if err := r.refreshRulesMap(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - if err := r.conn.Flush(); err != nil { - return fmt.Errorf("initialize tables: %v", err) + log.Errorf("failed to refresh rules: %s", err) } return nil @@ -675,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } // addPostroutingRules adds the masquerade rules -func (r *router) addPostroutingRules() error { +func (r *router) addPostroutingRules() { // First masquerade rule for traffic coming in from WireGuard interface exprs := []expr.Any{ // Match on the first fwmark @@ -741,8 +785,83 @@ func (r *router) addPostroutingRules() error { Chain: r.chains[chainNameRoutingNat], Exprs: exprs2, }) +} - return nil +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + exprsOut := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 13, + Len: 1, + }, + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 1, + Mask: []byte{0x02}, + Xor: []byte{0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00}, + }, + &expr.Counter{}, + &expr.Exthdr{ + DestRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + &expr.Cmp{ + Op: expr.CmpOpGt, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Exthdr{ + SourceRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameMangleForward], + Exprs: exprsOut, + }) + + return r.conn.Flush() } // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls @@ -840,41 +959,63 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +// This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { + var merr *multierror.Error + + if err := r.acceptFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) acceptFilterTableRules() error { if r.filterTable == nil { - log.Debugf("table 'filter' not found for forward rules, skipping accept rules") return nil } fw := "iptables" defer func() { - log.Debugf("Used %s to add accept forward rules", fw) + log.Debugf("Used %s to add accept forward and input rules", fw) }() // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { - // filter table exists but iptables is not + // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptForwardRulesNftables() + return r.acceptFilterRulesNftables(r.filterTable) } - return r.acceptForwardRulesIptables(ipt) + return r.acceptFilterRulesIptables(ipt) } -func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err)) } else { - log.Debugf("added iptables rule: %v", rule) + log.Debugf("added iptables forward rule: %v", rule) } } + inputRule := r.getAcceptInputRule() + if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err)) + } else { + log.Debugf("added iptables input rule: %v", inputRule) + } + return nberrors.FormatErrorOrNil(merr) } @@ -886,19 +1027,74 @@ func (r *router) getAcceptForwardRules() [][]string { } } -func (r *router) acceptForwardRulesNftables() error { +func (r *router) getAcceptInputRule() []string { + return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} +} + +// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables. +// This is used when iptables is not available. +func (r *router) acceptFilterRulesNftables(table *nftables.Table) error { intf := ifname(r.wgIface.Name()) - // Rule for incoming interface (iif) with counter + forwardChain := &nftables.Chain{ + Name: chainNameForward, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + } + r.insertForwardAcceptRules(forwardChain, intf) + + inputChain := &nftables.Chain{ + Name: chainNameInput, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + } + r.insertInputAcceptRule(inputChain, intf) + + return r.conn.Flush() +} + +// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables). +// It dynamically finds chains at call time to handle chains that may have been created after startup. +func (r *router) acceptExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + intf := ifname(r.wgIface.Name()) + + for _, chain := range chains { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + continue + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush external chain rules: %w", err) + } + + return nil +} + +func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { iifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -921,82 +1117,188 @@ func (r *router) acceptForwardRulesNftables() error { Data: intf, }, } - - // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } - r.conn.InsertRule(oifRule) - - return nil } -func (r *router) removeAcceptForwardRules() error { +func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { + inputRule := &nftables.Rule{ + Table: chain.Table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptInputRule), + } + r.conn.InsertRule(inputRule) +} + +func (r *router) removeAcceptFilterRules() error { + var merr *multierror.Error + + if err := r.removeFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.removeExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) removeFilterTableRules() error { if r.filterTable == nil { return nil } - // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { - log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptForwardRulesNftables() + log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) } - return r.removeAcceptForwardRulesIptables(ipt) + return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptForwardRulesNftables() error { - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) +func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { + chains, err := r.conn.ListChainsOfTableFamily(table.Family) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + if chain.Table.Name != table.Name { continue } - rules, err := r.conn.GetRules(r.filterTable, chain) - if err != nil { - return fmt.Errorf("get rules: %v", err) + if chain.Name != chainNameForward && chain.Name != chainNameInput { + continue } - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } + if err := r.removeAcceptRulesFromChain(table, chain); err != nil { + return err + } + } + + return r.conn.Flush() +} + +func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error { + rules, err := r.conn.GetRules(table, chain) + if err != nil { + return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err) + } + } + } + return nil +} + +// removeExternalChainsRules removes our accept rules from all external chains. +// This is deterministic - it scans for chains at removal time rather than relying on saved state, +// ensuring cleanup works even after a crash or if chains changed. +func (r *router) removeExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + for _, chain := range chains { + if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil { + log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err) + } + } + + return r.conn.Flush() +} + +// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks. +// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal). +func (r *router) findExternalChains() []*nftables.Chain { + var chains []*nftables.Chain + + families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + + for _, family := range families { + allChains, err := r.conn.ListChainsOfTableFamily(family) + if err != nil { + log.Debugf("list chains for family %d: %v", family, err) + continue + } + + for _, chain := range allChains { + if r.isExternalChain(chain) { + chains = append(chains, chain) } } } - if err := r.conn.Flush(); err != nil { - return fmt.Errorf(flushError, err) - } - - return nil + return chains } -func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) isExternalChain(chain *nftables.Chain) bool { + if r.workTable != nil && chain.Table.Name == r.workTable.Name { + return false + } + + // Skip all iptables-managed tables in the ip family + if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + return false + } + + if chain.Type != nftables.ChainTypeFilter { + return false + } + + if chain.Hooknum == nil { + return false + } + + return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput +} + +func isIptablesTable(name string) bool { + switch name { + case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity: + return true + } + return false +} + +func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err)) } } + inputRule := r.getAcceptInputRule() + if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err)) + } + return nberrors.FormatErrorOrNil(merr) } @@ -1056,7 +1358,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf(" unable to list rules: %v", err) + return fmt.Errorf("list rules: %w", err) } for _, rule := range rules { if len(rule.UserData) > 0 { @@ -1350,6 +1652,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 3, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 3, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + RegProtoMax: 0, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if rule, exists := r.rules[ruleID]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 4fdbf3505..3531b014b 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -17,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { // need fw manager to init both acl mgr and router for all chains to be present - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) @@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index f805623d6..48b7b3741 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -3,6 +3,7 @@ package nftables import ( "fmt" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -10,6 +11,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - nft, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + nft, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create nftables manager: %w", err) } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7eef49e31..4e22bde3f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "errors" "fmt" "net" @@ -27,7 +28,18 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -const layerTypeAll = 0 +const ( + layerTypeAll = 0 + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 +) + +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. @@ -36,6 +48,9 @@ const ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" + // EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic. + EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING" + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" @@ -109,6 +124,17 @@ type Manager struct { dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex dnatBiMap *biDNATMap + + portDNATEnabled atomic.Bool + portDNATRules []portDNATRule + portDNATMutex sync.RWMutex + + netstackServices map[serviceKey]struct{} + netstackServiceMutex sync.RWMutex + + mtu uint16 + mssClampValue uint16 + mssClampEnabled bool } // decoder for packages @@ -122,19 +148,21 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - return create(iface, nil, disableServerRoutes, flowLogger) +func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + return create(iface, nil, disableServerRoutes, flowLogger, mtu) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { if nativeFirewall == nil { return nil, errors.New("native firewall is nil") } - mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) + mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } @@ -142,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. return mgr, nil } -func parseCreateEnv() (bool, bool) { - var disableConntrack, enableLocalForwarding bool +func parseCreateEnv() (bool, bool, bool) { + var disableConntrack, enableLocalForwarding, disableMSSClamping bool var err error if val := os.Getenv(EnvDisableConntrack); val != "" { disableConntrack, err = strconv.ParseBool(val) @@ -162,12 +190,18 @@ func parseCreateEnv() (bool, bool) { log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) } } + if val := os.Getenv(EnvDisableMSSClamping); val != "" { + disableMSSClamping, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err) + } + } - return disableConntrack, enableLocalForwarding + return disableConntrack, enableLocalForwarding, disableMSSClamping } -func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - disableConntrack, enableLocalForwarding := parseCreateEnv() +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv() m := &Manager{ decoders: sync.Pool{ @@ -196,13 +230,19 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), + portDNATRules: []portDNATRule{}, + netstackServices: make(map[serviceKey]struct{}), + mtu: mtu, } m.routingEnabled.Store(false) + if !disableMSSClamping { + m.mssClampEnabled = true + m.mssClampValue = mtu - ipTCPHeaderMinSize + } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) } - if disableConntrack { log.Info("conntrack is disabled") } else { @@ -210,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) } - - // netstack needs the forwarder for local traffic if m.netstack && m.localForwarding { if err := m.initForwarder(); err != nil { log.Errorf("failed to initialize forwarder: %v", err) } } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } @@ -320,7 +357,7 @@ func (m *Manager) initForwarder() error { return errors.New("forwarding not supported") } - forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) + forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu) if err != nil { m.routingEnabled.Store(false) return fmt.Errorf("create forwarder: %w", err) @@ -626,11 +663,20 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return false } - if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { - return true + switch d.decoded[1] { + case layers.LayerTypeUDP: + if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { + return true + } + case layers.LayerTypeTCP: + // Clamp MSS on all TCP SYN packets, including those from local IPs. + // SNATed routed traffic may appear as local IP but still requires clamping. + if m.mssClampEnabled { + m.clampTCPMSS(packetData, d) + } } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) return false @@ -674,14 +720,117 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation. +// Both sides advertise their MSS during connection establishment, so we need to clamp both. +func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { + if !d.tcp.SYN { + return false + } + if len(d.tcp.Options) == 0 { + return false + } + + mssOptionIndex := -1 + var currentMSS uint16 + for i, opt := range d.tcp.Options { + if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { + currentMSS = binary.BigEndian.Uint16(opt.OptionData) + if currentMSS > m.mssClampValue { + mssOptionIndex = i + break + } + } + } + + if mssOptionIndex == -1 { + return false + } + + ipHeaderSize := int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + + if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { + return false + } + + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + return true +} + +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { + tcpHeaderStart := ipHeaderSize + tcpOptionsStart := tcpHeaderStart + 20 + + optOffset := tcpOptionsStart + for j := 0; j < mssOptionIndex; j++ { + switch d.tcp.Options[j].OptionType { + case layers.TCPOptionKindEndList, layers.TCPOptionKindNop: + optOffset++ + default: + optOffset += 2 + len(d.tcp.Options[j].OptionData) + } + } + + mssValueOffset := optOffset + 2 + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + + m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) + return true +} + +func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) { + tcpLayer := packetData[tcpHeaderStart:] + tcpLength := len(packetData) - tcpHeaderStart + + tcpLayer[16] = 0 + tcpLayer[17] = 0 + + var pseudoSum uint32 + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + + var sum uint32 = pseudoSum + for i := 0; i < tcpLength-1; i += 2 { + sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) + } + if tcpLength%2 == 1 { + sum += uint32(tcpLayer[tcpLength-1]) << 8 + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + checksum := ^uint16(sum) + binary.BigEndian.PutUint16(tcpLayer[16:18], checksum) +} + +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -691,13 +840,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -759,10 +910,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { + // Re-decode after port DNAT translation to update port information + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) @@ -807,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet return true } - // If requested we pass local traffic to internal interfaces to the forwarder. - // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. - if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { + if m.shouldForward(d, dstIP) { return m.handleForwardedLocalTraffic(packetData) } @@ -1243,3 +1402,86 @@ func (m *Manager) DisableRouting() error { return nil } + +// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port +func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + m.netstackServices[key] = struct{}{} + m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType) + m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices)) +} + +// UnregisterNetstackService removes a service from the netstack registry +func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + delete(m.netstackServices, key) + m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port) +} + +// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use +func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType { + switch protocol { + case nftypes.TCP: + return layers.LayerTypeTCP + case nftypes.UDP: + return layers.LayerTypeUDP + case nftypes.ICMP: + return layers.LayerTypeICMPv4 + default: + return gopacket.LayerType(0) // Invalid/unknown + } +} + +// shouldForward determines if a packet should be forwarded to the forwarder. +// The forwarder handles routing packets to the native OS network stack. +// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly. +func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { + // not enabled, never forward + if !m.localForwarding { + return false + } + + // netstack always needs to forward because it's lacking a native interface + // exception for registered netstack services, those should go to netstack listeners + if m.netstack { + return !m.hasMatchingNetstackService(d) + } + + // traffic to our other local interfaces (not NetBird IP) - always forward + if dstIP != m.wgIface.Address().IP { + return true + } + + // traffic to our NetBird IP, not netstack mode - send to netstack listeners + return false +} + +// hasMatchingNetstackService checks if there's a registered netstack service for this packet +func (m *Manager) hasMatchingNetstackService(d *decoder) bool { + if len(d.decoded) < 2 { + return false + } + + var dstPort uint16 + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + dstPort = uint16(d.udp.DstPort) + default: + return false + } + + key := serviceKey{protocol: d.decoded[1], port: dstPort} + m.netstackServiceMutex.RLock() + _, exists := m.netstackServices[key] + m.netstackServiceMutex.RUnlock() + + return exists +} diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 0cffcc1a7..5a2d0410f 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -17,6 +17,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { } } -// generateTCPPacketWithFlags creates a TCP packet with specific flags -func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { - b.Helper() - - ipv4 := &layers.IPv4{ - TTL: 64, - Version: 4, - SrcIP: srcIP, - DstIP: dstIP, - Protocol: layers.IPProtocolTCP, - } - - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - - // Set TCP flags - tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 - tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 - tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 - tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 - tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 - - require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) - return buf.Bytes() -} - func BenchmarkRouteACLs(b *testing.B) { manager := setupRoutedManager(b, "10.10.0.100/16") @@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) { } } } + +// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound. +// This shows the overhead difference between the common case (non-SYN packets, fast path) +// and the rare case (SYN packets that need clamping, expensive path). +func BenchmarkMSSClamping(b *testing.B) { + scenarios := []struct { + name string + description string + genPacket func(*testing.B, net.IP, net.IP) []byte + frequency string + }{ + { + name: "syn_needs_clamp", + description: "SYN packet needing MSS clamping", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + frequency: "~0.1% of traffic - EXPENSIVE", + }, + { + name: "syn_no_clamp_needed", + description: "SYN packet with already-small MSS", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200) + }, + frequency: "~0.05% of traffic", + }, + { + name: "tcp_ack", + description: "Non-SYN TCP packet (ACK, data transfer)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + frequency: "~60-70% of traffic - FAST PATH", + }, + { + name: "tcp_psh_ack", + description: "TCP data packet (PSH+ACK)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + frequency: "~10-20% of traffic - FAST PATH", + }, + { + name: "udp", + description: "UDP packet", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + frequency: "~20-30% of traffic - FAST PATH", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled +// for the common case (non-SYN TCP packets). +func BenchmarkMSSClampingOverhead(b *testing.B) { + scenarios := []struct { + name string + enabled bool + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "disabled_tcp_ack", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "enabled_tcp_ack", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "disabled_syn_needs_clamp", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "enabled_syn_needs_clamp", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = sc.enabled + if sc.enabled { + manager.mssClampValue = 1240 + } + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases +func BenchmarkMSSClampingMemory(b *testing.B) { + scenarios := []struct { + name string + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "tcp_ack_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "syn_needs_clamp", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "udp_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte { + b.Helper() + + ip := &layers.IPv4{ + Version: 4, + IHL: 5, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Seq: 1000, + Window: 65535, + } + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ip)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{}))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index 73f3face8..eb5aa3343 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -12,6 +12,7 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) require.NotNil(t, manager) @@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(tb, err) require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) @@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index bac06814d..120a9f418 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "fmt" "net" "net/netip" @@ -17,9 +18,11 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nbiface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -66,7 +69,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -86,7 +89,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -119,7 +122,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -215,7 +218,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -265,7 +268,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -304,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -367,7 +370,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface, false, flowLogger) + manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -413,7 +416,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() @@ -495,7 +498,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) time.Sleep(time.Second) @@ -522,7 +525,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() // Close the existing tracker @@ -729,7 +732,7 @@ func TestUpdateSetMerge(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -815,7 +818,7 @@ func TestUpdateSetDeduplication(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -923,3 +926,327 @@ func TestUpdateSetDeduplication(t *testing.T) { require.Equal(t, tc.expected, isAllowed, tc.desc) } } + +func TestMSSClamping(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, 1280) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") + expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) + require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + srcIP := net.ParseIP("100.10.0.2") + dstIP := net.ParseIP("8.8.8.8") + + t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + }) + + t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { + lowMSS := uint16(1200) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified") + }) + + t.Run("SYN-ACK packet gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + }) + + t.Run("Non-SYN packet unchanged", func(t *testing.T) { + packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck)) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Empty(t, d.tcp.Options, "ACK packet should have no options") + }) +} + +func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + ACK: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Window: 65535, + } + + if flags&uint16(conntrack.TCPSyn) != 0 { + tcpLayer.SYN = true + } + if flags&uint16(conntrack.TCPAck) != 0 { + tcpLayer.ACK = true + } + if flags&uint16(conntrack.TCPFin) != 0 { + tcpLayer.FIN = true + } + if flags&uint16(conntrack.TCPRst) != 0 { + tcpLayer.RST = true + } + if flags&uint16(conntrack.TCPPush) != 0 { + tcpLayer.PSH = true + } + + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func TestShouldForward(t *testing.T) { + // Set up test addresses + wgIP := netip.MustParseAddr("100.10.0.1") + otherIP := netip.MustParseAddr("100.10.0.2") + + // Create test manager with mock interface + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + // Set the mock to return our test WG IP + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)} + } + + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Helper to create decoder with TCP packet + createTCPDecoder := func(dstPort uint16) *decoder { + ipv4 := &layers.IPv4{ + Version: 4, + Protocol: layers.IPProtocolTCP, + SrcIP: net.ParseIP("192.168.1.100"), + DstIP: wgIP.AsSlice(), + } + tcp := &layers.TCP{ + SrcPort: 54321, + DstPort: layers.TCPPort(dstPort), + } + + err := tcp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")) + require.NoError(t, err) + + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + + err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) + require.NoError(t, err) + + return d + } + + tests := []struct { + name string + localForwarding bool + netstack bool + dstIP netip.Addr + serviceRegistered bool + servicePort uint16 + expected bool + description string + }{ + { + name: "no local forwarding", + localForwarding: false, + netstack: true, + dstIP: wgIP, + expected: false, + description: "should never forward when local forwarding disabled", + }, + { + name: "traffic to other local interface", + localForwarding: true, + netstack: false, + dstIP: otherIP, + expected: true, + description: "should forward traffic to our other local interfaces (not NetBird IP)", + }, + { + name: "traffic to NetBird IP, no netstack", + localForwarding: true, + netstack: false, + dstIP: wgIP, + expected: false, + description: "should send to netstack listeners (final return false path)", + }, + { + name: "traffic to our IP, netstack mode, no service", + localForwarding: true, + netstack: true, + dstIP: wgIP, + expected: true, + description: "should forward when in netstack mode with no matching service", + }, + { + name: "traffic to our IP, netstack mode, with service", + localForwarding: true, + netstack: true, + dstIP: wgIP, + serviceRegistered: true, + servicePort: 22, + expected: false, + description: "should send to netstack listeners when service is registered", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Configure manager + manager.localForwarding = tt.localForwarding + manager.netstack = tt.netstack + + // Register service if needed + if tt.serviceRegistered { + manager.RegisterNetstackService(nftypes.TCP, tt.servicePort) + defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort) + } + + // Create decoder for the test + decoder := createTCPDecoder(tt.servicePort) + if !tt.serviceRegistered { + decoder = createTCPDecoder(8080) // Use non-registered port + } + + // Test the method + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 42a3e0800..00cb3f1df 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -45,7 +45,7 @@ type Forwarder struct { netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { +func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow HandleLocal: false, }) - mtu, err := iface.GetDevice().MTU() - if err != nil { - return nil, fmt.Errorf("get MTU: %w", err) - } nicID := tcpip.NICID(1) endpoint := &endpoint{ logger: logger, @@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } if err := s.CreateNIC(nicID, endpoint); err != nil { - return nil, fmt.Errorf("failed to create NIC: %v", err) + return nil, fmt.Errorf("create NIC: %v", err) } protoAddr := tcpip.ProtocolAddress{ diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index d146de5e4..55743d975 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -49,7 +49,7 @@ type idleConn struct { conn *udpPacketConn } -func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { +func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ logger: logger, diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 5614e2ec3..139f702f2 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -50,6 +50,8 @@ type logMessage struct { arg4 any arg5 any arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } - func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { } } +func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { @@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { } } +// Trace8 logs a trace message with 8 arguments (8 placeholder in format string) +func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + default: + } + } +} + func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") @@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { argCount++ if msg.arg6 != nil { argCount++ + if msg.arg7 != nil { + argCount++ + if msg.arg8 != nil { + argCount++ + } + } } } } @@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) case 6: formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + case 7: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7) + case 8: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8) } *buf = append(*buf, formatted...) @@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} \ No newline at end of file +} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 27b752531..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/netip" + "slices" + "github.com/google/gopacket" "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -13,6 +15,21 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + +const ( + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 +) + +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } +// portDNATRule represents a port-specific DNAT rule. +type portDNATRule struct { + protocol gopacket.LayerType + origPort uint16 + targetPort uint16 + targetIP netip.Addr +} + +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap { } } +// set adds a bidirectional DNAT mapping between original and translated addresses. func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original } +// delete removes a bidirectional DNAT mapping for the given original address. func (b *biDNATMap) delete(original netip.Addr) { if translated, exists := b.forward[original]; exists { delete(b.forward, original) @@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) { } } +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists } +// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation. func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { - if !originalAddr.IsValid() || !translatedAddr.IsValid() { - return fmt.Errorf("invalid IP addresses") + if !originalAddr.IsValid() { + return fmt.Errorf("invalid original IP address") + } + if !translatedAddr.IsValid() { + return fmt.Errorf("invalid translated IP address") } if m.localipmanager.IsLocalIP(translatedAddr) { @@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr return nil } -// RemoveInternalDNATMapping removes a 1:1 IP address mapping +// RemoveInternalDNATMapping removes a 1:1 IP address mapping. func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { m.dnatMutex.Lock() defer m.dnatMutex.Unlock() @@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets +// translateOutboundDNAT applies DNAT translation to outbound packets. func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("Failed to rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic +// translateInboundReverse applies reverse DNAT to inbound return traffic. func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("Failed to rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// rewritePacketSource replaces the source IP address in the packet -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip return nil } +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 +// incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { return ^uint16(sum) } -// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errNatNotSupported @@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) return m.nativeFirewall.AddDNATRule(rule) } -// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +// DeleteDNATRule deletes outbound DNAT rule. func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { if m.nativeFirewall == nil { return errNatNotSupported } return m.nativeFirewall.DeleteDNATRule(rule) } + +// addPortRedirection adds a port redirection rule. +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + rule := portDNATRule{ + protocol: protocol, + origPort: sourcePort, + targetPort: targetPort, + targetIP: targetIP, + } + + m.portDNATRules = append(m.portDNATRules, rule) + m.portDNATEnabled.Store(true) + + return nil +} + +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// removePortRedirection removes a port redirection rule. +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) + + if len(m.portDNATRules) == 0 { + m.portDNATEnabled.Store(false) + } + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. +func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { + if !m.portDNATEnabled.Load() { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: + return false + } +} + +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error + +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { + m.portDNATMutex.RLock() + defer m.portDNATMutex.RUnlock() + + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue + } + + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { + return false + } + + if rule.origPort != port { + continue + } + + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) + return false + } + d.dnatOrigPort = rule.origPort + return true + } + return false +} + +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. +func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+4 { + return fmt.Errorf("packet too short for TCP header") + } + + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + if len(packetData) >= tcpStart+18 { + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + + return nil +} + +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. +func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") + } + + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + } + + return nil +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d2599e577 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) { func BenchmarkDNATConcurrency(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -414,3 +415,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_stateful_test.go b/client/firewall/uspfilter/nat_stateful_test.go new file mode 100644 index 000000000..21c6da06e --- /dev/null +++ b/client/firewall/uspfilter/nat_stateful_test.go @@ -0,0 +1,85 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +// TestPortDNATBasic tests basic port DNAT functionality +func TestPortDNATBasic(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Define peer IPs + peerA := netip.MustParseAddr("100.10.0.50") + peerB := netip.MustParseAddr("100.10.0.51") + + // Add SSH port redirection rule for peer B (the target) + err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + + // Scenario: Peer A connects to Peer B on port 22 (should get NAT) + packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d := parsePacket(t, packetAtoB) + translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB) + require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)") + + // Verify port was translated to 22022 + d = parsePacket(t, packetAtoB) + require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022") + + // Scenario: Return traffic from Peer B to Peer A should NOT be translated + // (prevents double NAT - original port stored in conntrack) + returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321) + d2 := parsePacket(t, returnPacket) + translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA) + require.False(t, translatedReturn, "Return traffic from same IP should not be translated") +} + +// TestPortDNATMultipleRules tests multiple port DNAT rules +func TestPortDNATMultipleRules(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Define peer IPs + peerA := netip.MustParseAddr("100.10.0.50") + peerB := netip.MustParseAddr("100.10.0.51") + + // Add SSH port redirection rules for both peers + err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + + // Test traffic to peer B gets translated + packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d1 := parsePacket(t, packetToB) + translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB) + require.True(t, translatedToB, "Traffic to peer B should be translated") + d1 = parsePacket(t, packetToB) + require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022") + + // Test traffic to peer A gets translated + packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22) + d2 := parsePacket(t, packetToA) + translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA) + require.True(t, translatedToA, "Traffic to peer A should be translated") + d2 = parsePacket(t, packetToA) + require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022") +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 710abd445..400d61020 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -8,6 +8,8 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -15,7 +17,7 @@ import ( func TestDNATTranslationCorrectness(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -99,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { func TestDNATMappingManagement(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -143,3 +145,111 @@ func TestDNATMappingManagement(t *testing.T) { err = manager.RemoveInternalDNATMapping(originalIP) require.Error(t, err, "Should error when removing non-existent mapping") } + +func TestInboundPortDNAT(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) + + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") + + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) + } + + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") + + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } +} + +func TestInboundPortDNATNegative(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) + + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) + + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..d9f9f1aa8 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -10,6 +10,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) if !statefulMode { @@ -104,6 +105,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +129,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +158,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +186,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +213,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +239,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +259,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +279,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +304,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +320,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +340,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +363,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +387,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +409,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +435,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..54966b50e 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "fmt" "runtime" "time" @@ -57,8 +58,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone }), ) if err != nil { - log.Printf("DialContext error: %v", err) - return nil, err + return nil, fmt.Errorf("dial context: %w", err) } return conn, nil diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..c0c829825 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..198343fbd 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -3,6 +3,7 @@ package device import ( + "fmt" "strings" log "github.com/sirupsen/logrus" @@ -19,11 +20,12 @@ import ( // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform type WGTunDevice struct { - address wgaddr.Address - port int - key string - mtu uint16 - iceBind *bind.ICEBind + address wgaddr.Address + port int + key string + mtu uint16 + iceBind *bind.ICEBind + // todo: review if we can eliminate the TunAdapter tunAdapter TunAdapter disableDNS bool @@ -32,17 +34,19 @@ type WGTunDevice struct { filteredDevice *FilteredDevice udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer + renewableTun *RenewableTUN } func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ - address: address, - port: port, - key: key, - mtu: mtu, - iceBind: iceBind, - tunAdapter: tunAdapter, - disableDNS: disableDNS, + address: address, + port: port, + key: key, + mtu: mtu, + iceBind: iceBind, + tunAdapter: tunAdapter, + disableDNS: disableDNS, + renewableTun: NewRenewableTUN(), } } @@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } - tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd) + unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd) if err != nil { _ = unix.Close(fd) log.Errorf("failed to create Android interface: %s", err) return nil, err } + + t.renewableTun.AddDevice(unmonitoredTUN) + t.name = name - t.filteredDevice = newDeviceFilter(tunDevice) + t.filteredDevice = newDeviceFilter(t.renewableTun) log.Debugf("attaching to interface %v", name) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) @@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return udpMux, nil } +func (t *WGTunDevice) RenewTun(fd int) error { + if t.device == nil { + return fmt.Errorf("device not initialized") + } + + unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd) + if err != nil { + _ = unix.Close(fd) + log.Errorf("failed to renew Android interface: %s", err) + return err + } + + t.renewableTun.AddDevice(unmonitoredTUN) + + return nil +} + func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { // todo implement return nil @@ -150,6 +174,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..acd5f6f11 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..f96edf992 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..2a836f846 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b68..40d8fdac8 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_netstack_android.go b/client/iface/device/device_netstack_android.go index 45ae8ba7d..f1a77d40a 100644 --- a/client/iface/device/device_netstack_android.go +++ b/client/iface/device/device_netstack_android.go @@ -2,6 +2,13 @@ package device +import "fmt" + func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { return t.create() } + +func (t *TunNetstackDevice) RenewTun(fd int) error { + // Doesn't make sense in Android for Netstack. + return fmt.Errorf("this function has not been implemented in Netstack for Android") +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..24654fc03 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..96350df8a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 000000000..b53888baa --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. +// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device/renewable_tun.go b/client/iface/device/renewable_tun.go new file mode 100644 index 000000000..a501eebbb --- /dev/null +++ b/client/iface/device/renewable_tun.go @@ -0,0 +1,309 @@ +//go:build android + +package device + +import ( + "io" + "os" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" +) + +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +// +// It also redirects tun.Device's Events() to a separate goroutine +// and closes it when Close is called. +// +// The WaitGroup and CloseOnce fields are used to ensure that the +// goroutine is awaited and closed only once. +type closeAwareDevice struct { + isClosed atomic.Bool + tun.Device + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newClosableDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel (RenewableTUN's events channel). +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() +} + +type RenewableTUN struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + events chan tun.Event + closed atomic.Bool +} + +func NewRenewableTUN() *RenewableTUN { + r := &RenewableTUN{ + devices: make([]*closeAwareDevice, 0), + mu: sync.Mutex{}, + events: make(chan tun.Event, 16), + } + r.cond = sync.NewCond(&r.mu) + return r +} + +func (r *RenewableTUN) File() *os.File { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// Read reads from an underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries reading from the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + for { + dev := r.peekLast() + if dev == nil { + // wait until AddDevice() signals a new device via cond.Broadcast() + if !r.waitForDevice() { // returns false if the renewable TUN itself is closed + return 0, io.EOF + } + continue + } + + n, err = dev.Read(bufs, sizes, offset) + if err == nil { + return n, nil + } + + // swap in progress; retry on the newest instead of returning the error + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return n, err // propagate non-swap error + } +} + +// Write writes to underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries writing to the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} + +func (r *RenewableTUN) MTU() (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return 0, err + } +} + +func (r *RenewableTUN) Name() (string, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return "", io.EOF + } + continue + } + name, err := dev.Name() + if err == nil { + return name, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return "", err + } +} + +// Events returns a channel that is fed events from the underlying tun.Device's events channel +// once it is added. +func (r *RenewableTUN) Events() <-chan tun.Event { + return r.events +} + +func (r *RenewableTUN) Close() error { + // Attempts to set the RenewableTUN closed flag to true. + // If it's already true, returns immediately. + if !r.closed.CompareAndSwap(false, true) { + return nil // already closed: idempotent + } + r.mu.Lock() + devices := r.devices + r.devices = nil + r.cond.Broadcast() + r.mu.Unlock() + + var lastErr error + + log.Debugf("closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + log.Debugf("error closing a device: %v", err) + lastErr = err + } + } + + close(r.events) + return lastErr +} + +func (r *RenewableTUN) BatchSize() int { + return 1 +} + +func (r *RenewableTUN) AddDevice(device tun.Device) { + r.mu.Lock() + if r.closed.Load() { + r.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(r.devices) > 0 { + toClose = r.devices[len(r.devices)-1] + } + + cad := newClosableDevice(device) + cad.redirectEvents(r.events) + + r.devices = []*closeAwareDevice{cad} + r.cond.Broadcast() + + r.mu.Unlock() + + if toClose != nil { + if err := toClose.Close(); err != nil { + log.Debugf("error closing last device: %v", err) + } + } +} + +func (r *RenewableTUN) waitForDevice() bool { + r.mu.Lock() + defer r.mu.Unlock() + + for len(r.devices) == 0 && !r.closed.Load() { + r.cond.Wait() + } + return !r.closed.Load() +} + +func (r *RenewableTUN) peekLast() *closeAwareDevice { + r.mu.Lock() + defer r.mu.Unlock() + + if len(r.devices) == 0 { + return nil + } + + return r.devices[len(r.devices)-1] +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..3899bf426 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,6 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + RenewTun(fd int) error + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 158672160..07235a995 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager userspace bind mode. +func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index 5e17c6d41..13ae9393c 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -24,3 +24,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on non mobile") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on non-android") +} diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go index 373a9c95a..d2d9eb70e 100644 --- a/client/iface/iface_create_android.go +++ b/client/iface/iface_create_android.go @@ -6,6 +6,7 @@ import ( // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. +// todo: review does this function really necessary or can we merge it with iOS func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { w.mu.Lock() defer w.mu.Unlock() @@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s func (w *WGIface) Create() error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + w.mu.Lock() + defer w.mu.Unlock() + return w.tun.RenewTun(fd) +} diff --git a/client/iface/iface_create_darwin.go b/client/iface/iface_create_darwin.go index 1d91bce54..0b7cd36ef 100644 --- a/client/iface/iface_create_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -39,3 +39,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on this platform") +} diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index e890b30f3..6bbfeaa63 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -1,6 +1,7 @@ package iface import ( + "context" "fmt" "net" "net/netip" @@ -9,13 +10,13 @@ import ( "time" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/stdnet" ) // keep darwin compatibility @@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) addr := "100.64.0.1/8" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) { func Test_CreateInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) wgIP := "10.99.99.1/32" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -166,7 +167,7 @@ func Test_Close(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) { func Test_UpdatePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.9/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.13/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) { peer2wgPort := 33200 keepAlive := 1 * time.Second - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) { guid = fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - newNet, err = stdnet.NewNet() + newNet, err = stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/iface/udpmux/mux.go b/client/iface/udpmux/mux.go index 319724926..c5d2de4a5 100644 --- a/client/iface/udpmux/mux.go +++ b/client/iface/udpmux/mux.go @@ -1,6 +1,7 @@ package udpmux import ( + "context" "fmt" "io" "net" @@ -12,8 +13,9 @@ import ( "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" ) /* @@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() { if len(networks) > 0 { if m.params.Net == nil { var err error - if m.params.Net, err = stdnet.NewNet(); err != nil { + if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil { m.params.Logger.Errorf("failed to get create network: %v", err) } } diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..dd6f9479a 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -17,7 +17,6 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" - "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -29,11 +28,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,30 +80,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) - - enableSSH := networkMap.PeerConfig != nil && - networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - - // if TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). - if enableSSH { - rules = append(rules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: strconv.Itoa(ssh.DefaultSSHPort), - }) - } + rules := networkMap.FirewallRules // if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag // we have old version of management without rules handling, we should allow all traffic @@ -368,145 +339,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. - addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. - continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. - for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..4bc0fd800 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/netflow" @@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -188,492 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string @@ -757,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) { }) } } - -func TestDefaultManagerEnableSSHRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - PeerConfig: &mgmProto.PeerConfig{ - SshConfig: &mgmProto.SSHConfig{ - SshEnabled: true, - }, - }, - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ifaceMock := mocks.NewMockIFaceMapper(ctrl) - ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() - ifaceMock.EXPECT().SetFilter(gomock.Any()) - network := netip.MustParsePrefix("172.0.0.1/32") - - ifaceMock.EXPECT().Name().Return("lo").AnyTimes() - ifaceMock.EXPECT().Address().Return(wgaddr.Address{ - IP: network.Addr(), - Network: network, - }).AnyTimes() - ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) - require.NoError(t, err) - defer func() { - err = fw.Close(nil) - require.NoError(t, err) - }() - - acl := NewDefaultManager(fw) - - acl.ApplyFiltering(networkMap, false) - - expectedRules := 3 - if fw.IsStateful() { - expectedRules = 3 // 2 inbound rules + SSH rule - } - assert.Equal(t, expectedRules, len(acl.peerRulesPairs)) -} diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index da4f16c8d..8ca760742 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow deviceCode.VerificationURIComplete = deviceCode.VerificationURI } + if d.providerConfig.LoginHint != "" { + deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint) + if deviceCode.VerificationURI != "" { + deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint) + } + } + return deviceCode, err } +func appendLoginHint(uri, loginHint string) string { + if uri == "" || loginHint == "" { + return uri + } + + parsedURL, err := url.Parse(uri) + if err != nil { + log.Debugf("failed to parse verification URI for login_hint: %v", err) + return uri + } + + query := parsedURL.Query() + query.Set("login_hint", loginHint) + parsedURL.RawQuery = query.Encode() + + return parsedURL.String() +} + func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { form := url.Values{} form.Add("client_id", d.providerConfig.ClientID) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 4458f600c..85a166005 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -60,38 +60,45 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } +func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool { + return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient +} + // NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration // // It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { - if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { - return authenticateWithDeviceCodeFlow(ctx, config) +// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV) +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) { + if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) { + return authenticateWithDeviceCodeFlow(ctx, config, hint) } - pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint) if err != nil { - // fallback to device code flow log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debug("falling back to device code flow") - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } return pkceFlow, nil } // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow -func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } + + pkceFlowInfo.ProviderConfig.LoginHint = hint + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow -func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { switch s, ok := gstatus.FromError(err); { @@ -107,5 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } + deviceFlowInfo.ProviderConfig.LoginHint = hint + return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 8741e8636..cc43c8648 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "time" @@ -21,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" + "github.com/netbirdio/netbird/shared/management/client/common" ) var _ OAuthFlow = &PKCEAuthorizationFlow{} @@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct { func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string - // find the first available redirect URL + excludedRanges := getSystemExcludedPortRanges() + for _, redirectURL := range config.RedirectURLs { - if !isRedirectURLPortUsed(redirectURL) { + if !isRedirectURLPortUsed(redirectURL, excludedRanges) { availableRedirectURL = redirectURL break } @@ -102,13 +105,16 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), } if !p.providerConfig.DisablePromptLogin { - if p.providerConfig.LoginFlag.IsPromptLogin() { + switch p.providerConfig.LoginFlag { + case common.LoginFlagPromptLogin: params = append(params, oauth2.SetAuthURLParam("prompt", "login")) - } - if p.providerConfig.LoginFlag.IsMaxAge0Login() { + case common.LoginFlagMaxAge0: params = append(params, oauth2.SetAuthURLParam("max_age", "0")) } } + if p.providerConfig.LoginHint != "" { + params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint)) + } authURL := p.oAuthConfig.AuthCodeURL(state, params...) @@ -189,17 +195,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, if authError := query.Get(queryError); authError != "" { authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + if authErrorDesc != "" { + return nil, fmt.Errorf("authentication failed: %s", authErrorDesc) + } + return nil, fmt.Errorf("authentication failed: %s", authError) } // Prevent timing attacks on the state if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") + return nil, fmt.Errorf("authentication failed: Invalid state") } code := query.Get(queryCode) if code == "" { - return nil, fmt.Errorf("missing code") + return nil, fmt.Errorf("authentication failed: missing code") } return p.oAuthConfig.Exchange( @@ -228,7 +237,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, } if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil { - return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err) } email, err := parseEmailFromIDToken(tokenInfo.IDToken) @@ -276,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string { return base64.RawURLEncoding.EncodeToString(sha2[:]) } -// isRedirectURLPortUsed checks if the port used in the redirect URL is in use. -func isRedirectURLPortUsed(redirectURL string) bool { +// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows. +func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool { parsedURL, err := url.Parse(redirectURL) if err != nil { log.Errorf("failed to parse redirect URL: %v", err) return true } - addr := fmt.Sprintf(":%s", parsedURL.Port()) + port := parsedURL.Port() + + if isPortInExcludedRange(port, excludedRanges) { + log.Warnf("port %s is in Windows excluded port range, skipping", port) + return true + } + + addr := fmt.Sprintf(":%s", port) conn, err := net.DialTimeout("tcp", addr, 3*time.Second) if err != nil { return false @@ -298,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool { return true } +// excludedPortRange represents a range of excluded ports. +type excludedPortRange struct { + start int + end int +} + +// isPortInExcludedRange checks if the given port is in any of the excluded ranges. +func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool { + if len(excludedRanges) == 0 { + return false + } + + portNum, err := strconv.Atoi(port) + if err != nil { + log.Debugf("invalid port number %s: %v", port, err) + return false + } + + for _, r := range excludedRanges { + if portNum >= r.start && portNum <= r.end { + return true + } + } + + return false +} + func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) { tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl) if err != nil { diff --git a/client/internal/auth/pkce_flow_other.go b/client/internal/auth/pkce_flow_other.go new file mode 100644 index 000000000..96df41539 --- /dev/null +++ b/client/internal/auth/pkce_flow_other.go @@ -0,0 +1,8 @@ +//go:build !windows + +package auth + +// getSystemExcludedPortRanges returns nil on non-Windows platforms. +func getSystemExcludedPortRanges() []excludedPortRange { + return nil +} diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b2347d12d..b77a17eaa 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -2,8 +2,11 @@ package auth import ( "context" + "fmt" + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal" @@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) { name string loginFlag mgm.LoginFlag disablePromptLogin bool - expect string + expectContains []string }{ { - name: "Prompt login", - loginFlag: mgm.LoginFlagPrompt, - expect: promptLogin, + name: "Prompt login", + loginFlag: mgm.LoginFlagPromptLogin, + expectContains: []string{promptLogin}, }, { - name: "Max age 0 login", - loginFlag: mgm.LoginFlagMaxAge0, - expect: maxAge0, + name: "Max age 0", + loginFlag: mgm.LoginFlagMaxAge0, + expectContains: []string{maxAge0}, }, { name: "Disable prompt login", - loginFlag: mgm.LoginFlagPrompt, + loginFlag: mgm.LoginFlagPromptLogin, disablePromptLogin: true, + expectContains: []string{}, + }, + { + name: "None flag should not add parameters", + loginFlag: mgm.LoginFlagNone, + expectContains: []string{}, }, } @@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) { RedirectURLs: []string{"http://127.0.0.1:33992/"}, UseIDToken: true, LoginFlag: tc.loginFlag, + DisablePromptLogin: tc.disablePromptLogin, } pkce, err := NewPKCEAuthorizationFlow(config) if err != nil { @@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) { t.Fatalf("Failed to request auth info: %v", err) } - if !tc.disablePromptLogin { - require.Contains(t, authInfo.VerificationURIComplete, tc.expect) - } else { - require.Contains(t, authInfo.VerificationURIComplete, promptLogin) - require.NotContains(t, authInfo.VerificationURIComplete, maxAge0) + for _, expected := range tc.expectContains { + require.Contains(t, authInfo.VerificationURIComplete, expected) } }) } } + +func TestIsPortInExcludedRange(t *testing.T) { + tests := []struct { + name string + port string + excludedRanges []excludedPortRange + expectedBlocked bool + }{ + { + name: "Port in excluded range", + port: "8080", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at start of range", + port: "8000", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at end of range", + port: "8100", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port before range", + port: "7999", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Port after range", + port: "8101", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty excluded ranges", + port: "8080", + excludedRanges: []excludedPortRange{}, + expectedBlocked: false, + }, + { + name: "Nil excluded ranges", + port: "8080", + excludedRanges: nil, + expectedBlocked: false, + }, + { + name: "Multiple ranges - port in second range", + port: "9050", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: true, + }, + { + name: "Multiple ranges - port not in any range", + port: "8500", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: false, + }, + { + name: "Invalid port string", + port: "invalid", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty port string", + port: "", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isPortInExcludedRange(tt.port, tt.excludedRanges) + assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch") + }) + } +} + +func TestIsRedirectURLPortUsed(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + + usedPort := listener.Addr().(*net.TCPAddr).Port + + tests := []struct { + name string + redirectURL string + excludedRanges []excludedPortRange + expectedUsed bool + }{ + { + name: "Port in excluded range", + redirectURL: "http://127.0.0.1:8080/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + { + name: "Port actually in use", + redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort), + excludedRanges: nil, + expectedUsed: true, + }, + { + name: "Port not in use and not excluded", + redirectURL: "http://127.0.0.1:65432/", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Invalid URL without port", + redirectURL: "not-a-valid-url", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Port excluded even if not in use", + redirectURL: "http://127.0.0.1:8050/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges) + assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch") + }) + } +} diff --git a/client/internal/auth/pkce_flow_windows.go b/client/internal/auth/pkce_flow_windows.go new file mode 100644 index 000000000..cf3f8718f --- /dev/null +++ b/client/internal/auth/pkce_flow_windows.go @@ -0,0 +1,86 @@ +//go:build windows + +package auth + +import ( + "bufio" + "fmt" + "os/exec" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" +) + +// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh. +func getSystemExcludedPortRanges() []excludedPortRange { + ranges, err := getExcludedPortRangesFromNetsh() + if err != nil { + log.Debugf("failed to get Windows excluded port ranges: %v", err) + return nil + } + + return ranges +} + +// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command. +func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) { + cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("netsh command: %w", err) + } + + return parseExcludedPortRanges(string(output)) +} + +// parseExcludedPortRanges parses the output of the netsh command to extract port ranges. +func parseExcludedPortRanges(output string) ([]excludedPortRange, error) { + var ranges []excludedPortRange + scanner := bufio.NewScanner(strings.NewReader(output)) + + foundHeader := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") { + foundHeader = true + continue + } + + if !foundHeader { + continue + } + + if strings.Contains(line, "----------") { + continue + } + + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + startPort, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + + endPort, err := strconv.Atoi(fields[1]) + if err != nil { + continue + } + + ranges = append(ranges, excludedPortRange{start: startPort, end: endPort}) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan output: %w", err) + } + + return ranges, nil +} diff --git a/client/internal/auth/pkce_flow_windows_test.go b/client/internal/auth/pkce_flow_windows_test.go new file mode 100644 index 000000000..dd455b2fe --- /dev/null +++ b/client/internal/auth/pkce_flow_windows_test.go @@ -0,0 +1,116 @@ +//go:build windows + +package auth + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" +) + +func TestParseExcludedPortRanges(t *testing.T) { + tests := []struct { + name string + netshOutput string + expectedRanges []excludedPortRange + expectError bool + }{ + { + name: "Valid netsh output with multiple ranges", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 + +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 5357 5357 * + 50000 50059 * +`, + expectedRanges: []excludedPortRange{ + {start: 5357, end: 5357}, + {start: 50000, end: 50059}, + }, + expectError: false, + }, + { + name: "Empty output", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 +`, + expectedRanges: nil, + expectError: false, + }, + { + name: "Single range", + netshOutput: ` +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 8080 8090 +`, + expectedRanges: []excludedPortRange{ + {start: 8080, end: 8090}, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseExcludedPortRanges(tt.netshOutput) + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedRanges, ranges) + } + }) + } +} + +func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) { + ranges := getSystemExcludedPortRanges() + t.Logf("Found %d excluded port ranges on this system", len(ranges)) + + listener1, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener1.Close() + }() + usedPort1 := listener1.Addr().(*net.TCPAddr).Port + + availablePort := 65432 + + config := internal.PKCEAuthProviderConfig{ + ClientID: "test-client-id", + Audience: "test-audience", + TokenEndpoint: "https://test-token-endpoint.com/token", + Scope: "openid email profile", + AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", + RedirectURLs: []string{ + fmt.Sprintf("http://127.0.0.1:%d/", usedPort1), + fmt.Sprintf("http://127.0.0.1:%d/", availablePort), + }, + UseIDToken: true, + } + + flow, err := NewPKCEAuthorizationFlow(config) + require.NoError(t, err) + require.NotNil(t, flow) + assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort), + "Should skip port in use and select available port") +} diff --git a/client/internal/connect.go b/client/internal/connect.go index c9331baf5..017c8bf10 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -24,9 +24,14 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/internal/updatemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/client/system" mgm "github.com/netbirdio/netbird/shared/management/client" mgmProto "github.com/netbirdio/netbird/shared/management/proto" @@ -34,16 +39,17 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) type ConnectClient struct { - ctx context.Context - config *profilemanager.Config - statusRecorder *peer.Status - engine *Engine - engineMutex sync.Mutex + ctx context.Context + config *profilemanager.Config + statusRecorder *peer.Status + doInitialAutoUpdate bool + + engine *Engine + engineMutex sync.Mutex persistSyncResponse bool } @@ -52,13 +58,15 @@ func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, + doInitalAutoUpdate bool, ) *ConnectClient { return &ConnectClient{ - ctx: ctx, - config: config, - statusRecorder: statusRecorder, - engineMutex: sync.Mutex{}, + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + doInitialAutoUpdate: doInitalAutoUpdate, + engineMutex: sync.Mutex{}, } } @@ -74,6 +82,7 @@ func (c *ConnectClient) RunOnAndroid( networkChangeListener listener.NetworkChangeListener, dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, + stateFilePath string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -82,6 +91,7 @@ func (c *ConnectClient) RunOnAndroid( NetworkChangeListener: networkChangeListener, HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil) } @@ -160,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return err } + var path string + if runtime.GOOS == "ios" || runtime.GOOS == "android" { + // On mobile, use the provided state file path directly + if !fileExists(mobileDependency.StateFilePath) { + if err := createFile(mobileDependency.StateFilePath); err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + path = mobileDependency.StateFilePath + } else { + sm := profilemanager.NewServiceManager("") + path = sm.GetStatePath() + } + stateManager := statemanager.New(path) + stateManager.RegisterState(&sshconfig.ShutdownState{}) + + updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager) + if err == nil { + updateManager.CheckUpdateSuccess(c.ctx) + + inst := installer.New() + if err := inst.CleanUpInstallerFiles(); err != nil { + log.Errorf("failed to clean up temporary installer file: %v", err) + } + } + defer c.statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle @@ -271,15 +308,25 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) - c.engine.SetSyncResponsePersistence(c.persistSyncResponse) + engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager) + engine.SetSyncResponsePersistence(c.persistSyncResponse) + c.engine = engine c.engineMutex.Unlock() - if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { + if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } + if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil { + // AutoUpdate will be true when the user click on "Connect" menu on the UI + if c.doInitialAutoUpdate { + log.Infof("start engine by ui, run auto-update check") + c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate) + c.doInitialAutoUpdate = false + } + } + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) @@ -289,15 +336,20 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } <-engineCtx.Done() + c.engineMutex.Lock() - if c.engine != nil && c.engine.wgInterface != nil { - log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) - if err := c.engine.Stop(); err != nil { + c.engine = nil + c.engineMutex.Unlock() + + // todo: consider to remove this condition. Is not thread safe. + // We should always call Stop(), but we need to verify that it is idempotent + if engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - c.engine = nil } - c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() @@ -382,19 +434,12 @@ func (c *ConnectClient) Status() StatusType { } func (c *ConnectClient) Stop() error { - if c == nil { - return nil + engine := c.Engine() + if engine != nil { + if err := engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } } - c.engineMutex.Lock() - defer c.engineMutex.Unlock() - - if c.engine == nil { - return nil - } - if err := c.engine.Stop(); err != nil { - return fmt.Errorf("stop engine: %w", err) - } - return nil } @@ -420,20 +465,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf nm = *config.NetworkMonitor } engineConf := &EngineConfig{ - WgIfaceName: config.WgIface, - WgAddr: peerConfig.Address, - IFaceBlackList: config.IFaceBlackList, - DisableIPv6Discovery: config.DisableIPv6Discovery, - WgPrivateKey: key, - WgPort: config.WgPort, - NetworkMonitor: nm, - SSHKey: []byte(config.SSHKey), - NATExternalIPs: config.NATExternalIPs, - CustomDNSAddress: config.CustomDNSAddress, - RosenpassEnabled: config.RosenpassEnabled, - RosenpassPermissive: config.RosenpassPermissive, - ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), - DNSRouteInterval: config.DNSRouteInterval, + WgIfaceName: config.WgIface, + WgAddr: peerConfig.Address, + IFaceBlackList: config.IFaceBlackList, + DisableIPv6Discovery: config.DisableIPv6Discovery, + WgPrivateKey: key, + WgPort: config.WgPort, + NetworkMonitor: nm, + SSHKey: []byte(config.SSHKey), + NATExternalIPs: config.NATExternalIPs, + CustomDNSAddress: config.CustomDNSAddress, + RosenpassEnabled: config.RosenpassEnabled, + RosenpassPermissive: config.RosenpassPermissive, + ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), + EnableSSHRoot: config.EnableSSHRoot, + EnableSSHSFTP: config.EnableSSHSFTP, + EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding, + EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding, + DisableSSHAuth: config.DisableSSHAuth, + DNSRouteInterval: config.DNSRouteInterval, DisableClientRoutes: config.DisableClientRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound, @@ -519,6 +569,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..01a0377a5 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" mgmProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -44,16 +45,19 @@ interfaces.txt: Anonymized network interface information, if --system-info flag ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. +resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. +scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. heap.prof: Heap profiling information (snapshot of memory allocations). allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. +stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. Anonymization Process @@ -107,6 +111,9 @@ go tool pprof -http=:8088 heap.prof This will open a web browser tab with the profiling information. +Stack Trace +The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created. + Routes The routes.txt file contains detailed routing table information in a tabular format: @@ -184,6 +191,20 @@ The ip_rules.txt file contains detailed IP routing rule information: The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. + +DNS Configuration +The debug bundle includes platform-specific DNS configuration files: + +resolv.conf (Unix systems): +- Contains DNS resolver configuration from /etc/resolv.conf +- Includes nameserver entries, search domains, and resolver options +- All IP addresses and domain names are anonymized following the same rules as other files + +scutil_dns.txt (macOS only): +- Contains detailed DNS configuration from scutil --dns +- Shows DNS configuration for all network interfaces +- Includes search domains, nameservers, and DNS resolver settings +- All IP addresses and domain names are anonymized ` const ( @@ -311,6 +332,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add profiles to debug bundle: %v", err) } + if err := g.addStackTrace(); err != nil { + log.Errorf("failed to add stack trace to debug bundle: %v", err) + } + if err := g.addSyncResponse(); err != nil { return fmt.Errorf("add sync response: %w", err) } @@ -338,6 +363,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add systemd logs: %v", err) } + if err := g.addUpdateLogs(); err != nil { + log.Errorf("failed to add updater logs: %v", err) + } + return nil } @@ -357,6 +386,10 @@ func (g *BundleGenerator) addSystemInfo() { if err := g.addFirewallRules(); err != nil { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + + if err := g.addDNSInfo(); err != nil { + log.Errorf("failed to add DNS info to debug bundle: %v", err) + } } func (g *BundleGenerator) addReadme() error { @@ -433,6 +466,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) if g.internalConfig.ServerSSHAllowed != nil { configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) } + if g.internalConfig.EnableSSHRoot != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot)) + } + if g.internalConfig.EnableSSHSFTP != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP)) + } + if g.internalConfig.EnableSSHLocalPortForwarding != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding)) + } + if g.internalConfig.EnableSSHRemotePortForwarding != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) + } configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) @@ -490,6 +535,18 @@ func (g *BundleGenerator) addProf() (err error) { return nil } +func (g *BundleGenerator) addStackTrace() error { + buf := make([]byte, 5242880) // 5 MB buffer + n := runtime.Stack(buf, true) + + stackTrace := bytes.NewReader(buf[:n]) + if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil { + return fmt.Errorf("add stack trace file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addInterfaces() error { interfaces, err := net.Interfaces() if err != nil { @@ -564,6 +621,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { @@ -596,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error { return nil } +func (g *BundleGenerator) addUpdateLogs() error { + inst := installer.New() + logFiles := inst.LogFiles() + if len(logFiles) == 0 { + return nil + } + + log.Infof("adding updater logs") + for _, logFile := range logFiles { + data, err := os.ReadFile(logFile) + if err != nil { + log.Warnf("failed to read update log file %s: %v", logFile, err) + continue + } + + baseName := filepath.Base(logFile) + if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil { + return fmt.Errorf("add update log file %s to zip: %w", baseName, err) + } + } + return nil +} + func (g *BundleGenerator) addCorruptedStateFiles() error { sm := profilemanager.NewServiceManager("") pattern := sm.GetStatePath() diff --git a/client/internal/debug/debug_darwin.go b/client/internal/debug/debug_darwin.go new file mode 100644 index 000000000..91e10214f --- /dev/null +++ b/client/internal/debug/debug_darwin.go @@ -0,0 +1,53 @@ +//go:build darwin && !ios + +package debug + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + if err := g.addScutilDNS(); err != nil { + log.Errorf("failed to add scutil DNS output: %v", err) + } + + return nil +} + +func (g *BundleGenerator) addScutilDNS() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "scutil", "--dns") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("execute scutil --dns: %w", err) + } + + if len(bytes.TrimSpace(output)) == 0 { + return fmt.Errorf("no scutil DNS output") + } + + content := string(output) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil { + return fmt.Errorf("add scutil DNS output to zip: %w", err) + } + + return nil +} diff --git a/client/internal/debug/debug_mobile.go b/client/internal/debug/debug_mobile.go index c00c65132..3c1745ff3 100644 --- a/client/internal/debug/debug_mobile.go +++ b/client/internal/debug/debug_mobile.go @@ -5,3 +5,7 @@ package debug func (g *BundleGenerator) addRoutes() error { return nil } + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_nondarwin.go b/client/internal/debug/debug_nondarwin.go new file mode 100644 index 000000000..dfc2eace5 --- /dev/null +++ b/client/internal/debug/debug_nondarwin.go @@ -0,0 +1,16 @@ +//go:build unix && !darwin && !android + +package debug + +import ( + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + return nil +} diff --git a/client/internal/debug/debug_nonunix.go b/client/internal/debug/debug_nonunix.go new file mode 100644 index 000000000..18d017050 --- /dev/null +++ b/client/internal/debug/debug_nonunix.go @@ -0,0 +1,7 @@ +//go:build !unix + +package debug + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_unix.go b/client/internal/debug/debug_unix.go new file mode 100644 index 000000000..7e8a74eb0 --- /dev/null +++ b/client/internal/debug/debug_unix.go @@ -0,0 +1,29 @@ +//go:build unix && !android + +package debug + +import ( + "fmt" + "os" + "strings" +) + +const resolvConfPath = "/etc/resolv.conf" + +func (g *BundleGenerator) addResolvConf() error { + data, err := os.ReadFile(resolvConfPath) + if err != nil { + return fmt.Errorf("read %s: %w", resolvConfPath, err) + } + + content := string(data) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil { + return fmt.Errorf("add resolv.conf to zip: %w", err) + } + + return nil +} diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index e4b4c2368..8233ca510 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -14,6 +14,9 @@ type WGIface interface { } func (g *BundleGenerator) addWgShow() error { + if g.statusRecorder == nil { + return fmt.Errorf("no status recorder available for wg show") + } result, err := g.statusRecorder.PeersStatus() if err != nil { return err diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index 6bd29801d..7f7d06130 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct { Scope string // UseIDToken indicates if the id token should be used for authentication UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it diff --git a/client/internal/dns.go b/client/internal/dns.go index 5e604bec5..3c68e4d00 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple var records []nbdns.SimpleRecord for _, zone := range config.CustomZones { + if zone.SkipPTRProcess { + continue + } for _, record := range zone.Records { if record.Type != int(dns.TypeA) { continue @@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) { records := collectPTRRecords(config, network) reverseZone := nbdns.CustomZone{ - Domain: zoneName, - Records: records, + Domain: zoneName, + Records: records, + SearchDomainDisabled: true, } config.CustomZones = append(config.CustomZones, reverseZone) diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index fa474afde..f7dc46a6b 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -11,11 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -const ( - ipv4ReverseZone = ".in-addr.arpa." - ipv6ReverseZone = ".ip6.arpa." -) - type hostManager interface { applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error @@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H } for _, customZone := range dnsConfig.CustomZones { - matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), - MatchOnly: matchOnly, + MatchOnly: customZone.SearchDomainDisabled, }) } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) @@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 290395473..d01be0c2c 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "net/url" "strings" "sync" @@ -26,6 +27,11 @@ type Resolver struct { mutex sync.RWMutex } +type ipsResponse struct { + ips []netip.Addr + err error +} + // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ @@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + ips, err := lookupIPWithExtraTimeout(ctx, d) if err != nil { - return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + return err } var aRecords, aaaaRecords []dns.RR @@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { return nil } +func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { + log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) + defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) + resultChan := make(chan *ipsResponse, 1) + + go func() { + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + resultChan <- &ipsResponse{ + err: err, + ips: ips, + } + }() + + var resp *ipsResponse + + select { + case <-time.After(dnsTimeout + time.Millisecond*500): + log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) + return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) + case <-ctx.Done(): + return nil, ctx.Err() + case resp = <-resultChan: + } + + if resp.err != nil { + return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + } + return resp.ips, nil +} + // PopulateFromConfig extracts and caches domains from the client configuration. func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { if mgmtURL == nil { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..94945b55a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + shutdownWg sync.WaitGroup // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. // This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool @@ -79,6 +80,7 @@ type DefaultServer struct { updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + currentConfigHash uint64 handlerChain *HandlerChain extraDomains map[domain.Domain]int @@ -206,6 +208,7 @@ func newDefaultServer( hostsDNSHolder: newHostsDNSHolder(), hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, + currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied } // register with root zone, handler chain takes care of the routing @@ -318,6 +321,7 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { s.ctxCancel() + s.shutdownWg.Wait() s.mux.Lock() defer s.mux.Unlock() @@ -507,8 +511,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.applyHostConfig() + s.shutdownWg.Add(1) go func() { - // persist dns state right away + defer s.shutdownWg.Done() if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } @@ -583,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() { log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains)) + hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err) + // Fall through to apply config anyway (fail-safe approach) + } else if s.currentConfigHash == hash { + log.Debugf("not applying host config as there are no changes") + return + } + + log.Debugf("applying host config as there are changes") if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { log.Errorf("failed to apply DNS host manager update: %v", err) + return + } + + // Only update hash if it was computed successfully and config was applied + if err == nil { + s.currentConfigHash = hash } s.registerFallback(config) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 11575d500..fe1f67f66 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { privKey, _ := wgtypes.GenerateKey() - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err @@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface, false, flowLogger) + pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err @@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) { "other.example.com.", "duplicate.example.com.", }, - applyHostConfigCall: 4, + // Expect 3 calls instead of 4 because when deregistering duplicate.example.com, + // the domain remains in the config (ref count goes from 2 to 1), so the host + // config hash doesn't change and applyDNSConfig is not called. + applyHostConfigCall: 3, }, { name: "Config update with new domains after registration", @@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) { expectedMatchOnly: []string{ "extra.example.com.", }, - applyHostConfigCall: 3, + // Expect 2 calls instead of 3 because when deregistering protected.example.com, + // it's removed from extraDomains but still remains in the config (from customZones), + // so the host config hash doesn't change and applyDNSConfig is not called. + applyHostConfigCall: 2, }, { name: "Register domain that is part of nameserver group", diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0e8a53a63..d9854c033 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -31,6 +31,7 @@ const ( systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS" systemdDbusResolvConfModeForeign = "foreign" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" @@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana log.Warnf("failed to set DNSSEC to 'no': %v", err) } + // We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil { + log.Warnf("failed to set DNSOverTLS to 'no': %v", err) + } + var ( searchDomains []string matchDomains []string diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c19e0acb5..2a92fd6d8 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add timeoutMsg += " " + peerInfo } timeoutMsg += fmt.Sprintf(" - error: %v", err) - logger.Warnf(timeoutMsg) + logger.Warn(timeoutMsg) } func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go index c23f0f31d..44ebe290b 100644 --- a/client/internal/dnsfwd/cache_test.go +++ b/client/internal/dnsfwd/cache_test.go @@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) { t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) } } - diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 7a262fa4c..6b8042ccb 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -33,7 +34,7 @@ type firewaller interface { } type DNSForwarder struct { - listenAddress string + listenAddress netip.AddrPort ttl uint32 statusRecorder *peer.Status @@ -47,9 +48,11 @@ type DNSForwarder struct { firewall firewaller resolver resolver cache *cache + + wgIface wgIface } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, @@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat statusRecorder: statusRecorder, resolver: net.DefaultResolver, cache: newCache(), + wgIface: wgIface, } } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("starting DNS forwarder on address=%s", f.listenAddress) + var netstackNet *netstack.Net + if f.wgIface != nil { + netstackNet = f.wgIface.GetNet() + } + + addrDesc := f.listenAddress.String() + if netstackNet != nil { + addrDesc = fmt.Sprintf("netstack %s", f.listenAddress) + } + log.Infof("starting DNS forwarder on address=%s", addrDesc) + + udpLn, err := f.createUDPListener(netstackNet) + if err != nil { + return fmt.Errorf("create UDP listener: %w", err) + } + + tcpLn, err := f.createTCPListener(netstackNet) + if err != nil { + return fmt.Errorf("create TCP listener: %w", err) + } - // UDP server mux := dns.NewServeMux() f.mux = mux mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ - Addr: f.listenAddress, - Net: "udp", - Handler: mux, + PacketConn: udpLn, + Handler: mux, } - // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ - Addr: f.listenAddress, - Net: "tcp", - Handler: tcpMux, + Listener: tcpLn, + Handler: tcpMux, } f.UpdateDomains(entries) @@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { errCh := make(chan error, 2) go func() { - log.Infof("DNS UDP listener running on %s", f.listenAddress) - errCh <- f.dnsServer.ListenAndServe() + log.Infof("DNS UDP listener running on %s", addrDesc) + errCh <- f.dnsServer.ActivateAndServe() }() go func() { - log.Infof("DNS TCP listener running on %s", f.listenAddress) - errCh <- f.tcpServer.ListenAndServe() + log.Infof("DNS TCP listener running on %s", addrDesc) + errCh <- f.tcpServer.ActivateAndServe() }() - // return the first error we get (e.g. bind failure or shutdown) return <-errCh } +func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) { + if netstackNet != nil { + return netstackNet.ListenUDPAddrPort(f.listenAddress) + } + + return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress)) +} + +func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) { + if netstackNet != nil { + return netstackNet.ListenTCPAddrPort(f.listenAddress) + } + + return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress)) +} + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() @@ -200,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns return nil } + // Unmap IPv4-mapped IPv6 addresses that some resolvers may return + for i, ip := range ips { + ips[i] = ip.Unmap() + } + f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) f.cache.set(domain, question.Qtype, ips) diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c1c95a2c1..4d0b96a75 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) } - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString(tt.configuredDomain) @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { mockResolver := &MockResolver{} // Set up forwarder - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Create entries and track sets @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Configure a single domain @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) d, err := domain.FromString(tt.configured) require.NoError(t, err) @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) { // Test that large UDP responses are truncated with TC bit set mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, _ := domain.FromString("example.com") @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { // a subsequent upstream failure still returns a successful response from cache. func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Verifies that cache normalization works across casing and trailing dot variations. func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("ExAmPlE.CoM") @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Set up complex overlapping patterns @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) query := &dns.Msg{} // Don't set any question diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..58b88d9ef 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,27 +4,34 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/wgaddr" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex +const ( + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) -const ( - dnsTTL = 60 //seconds -) +// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder. +type wgIface interface { + GetNet() *netstack.Net + Address() wgaddr.Address +} // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. type ForwarderEntry struct { @@ -36,24 +43,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + wgIface wgIface + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, + wgIface: wgIface, + serverPort: serverPort, } } @@ -67,13 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() + localAddr := m.wgIface.Address().IP + + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) + } } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + listenAddress := netip.AddrPortFrom(localAddr, m.serverPort) + m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface) + go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -98,6 +123,20 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + localAddr := m.wgIface.Address().IP + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } + + m.unregisterNetstackServices() + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -113,7 +152,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { @@ -122,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error { dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add udp firewall rule: %w", err) } - m.fwRules = dnsRules tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add tcp firewall rule: %w", err) } + + if err := m.firewall.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + m.fwRules = dnsRules m.tcpRules = tcpRules + m.registerNetstackServices() + return nil } +func (m *Manager) registerNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, m.serverPort) + registrar.RegisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + +func (m *Manager) unregisterNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort) + registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error for _, rule := range m.fwRules { diff --git a/client/internal/engine.go b/client/internal/engine.go index 646e059d4..55645b494 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -9,7 +9,6 @@ import ( "net/netip" "net/url" "os" - "reflect" "runtime" "slices" "sort" @@ -30,7 +29,6 @@ import ( firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" - nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" @@ -44,17 +42,16 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" - "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -75,6 +72,7 @@ const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms connInitLimit = 200 + disableAutoUpdate = "disabled" ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -115,7 +113,12 @@ type EngineConfig struct { RosenpassEnabled bool RosenpassPermissive bool - ServerSSHAllowed bool + ServerSSHAllowed bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool DNSRouteInterval time.Duration @@ -173,8 +176,7 @@ type Engine struct { networkMonitor *networkmonitor.NetworkMonitor - sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) - sshServer nbssh.Server + sshServer sshServer statusRecorder *peer.Status @@ -199,12 +201,16 @@ type Engine struct { connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager - // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup + // auto-update + updateManager *updatemanager.Manager - // dns forwarder port - dnsFwdPort uint16 + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + + // shutdownWg tracks all long-running goroutines to ensure clean shutdown + shutdownWg sync.WaitGroup + + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -218,17 +224,7 @@ type localIpUpdater interface { } // NewEngine creates a new Connection Engine with probes attached -func NewEngine( - clientCtx context.Context, - clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, - config *EngineConfig, - mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, -) *Engine { +func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine { engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, @@ -243,29 +239,13 @@ func NewEngine( STUNs: []*stun.URI{}, TURNs: []*stun.URI{}, networkSerial: 0, - sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, + stateManager: stateManager, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } - sm := profilemanager.NewServiceManager("") - - path := sm.GetStatePath() - if runtime.GOOS == "ios" { - if !fileExists(mobileDep.StateFilePath) { - err := createFile(mobileDep.StateFilePath) - if err != nil { - log.Errorf("failed to create state file: %v", err) - // we are not exiting as we can run without the state manager - } - } - - path = mobileDep.StateFilePath - } - engine.stateManager = statemanager.New(path) - log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) return engine } @@ -277,7 +257,6 @@ func (e *Engine) Stop() error { return nil } e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() if e.connMgr != nil { e.connMgr.Close() @@ -289,8 +268,11 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") - // stop/restore DNS first so dbus and friends don't complain because of a missing interface - e.stopDNSServer() + if err := e.stopSSHServer(); err != nil { + log.Warnf("failed to stop SSH server: %v", err) + } + + e.cleanupSSHConfig() if e.ingressGatewayMgr != nil { if err := e.ingressGatewayMgr.Close(); err != nil { @@ -299,37 +281,37 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } - if e.routeManager != nil { - e.routeManager.Stop(e.stateManager) - } - - if e.dnsForwardMgr != nil { - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - if e.srWatcher != nil { e.srWatcher.Close() } + if e.updateManager != nil { + e.updateManager.Stop() + } + + log.Info("cleaning up status recorder states") e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } + if e.routeManager != nil { + e.routeManager.Stop(e.stateManager) + } + + e.stopDNSForwarder() + + // stop/restore DNS after peers are closed but before interface goes down + // so dbus and friends don't complain because of a missing interface + e.stopDNSServer() + if e.cancel != nil { e.cancel() } - // very ugly but we want to remove peers from the WireGuard interface first before removing interface. - // Removing peers happens in the conn.Close() asynchronously - time.Sleep(500 * time.Millisecond) - e.close() // stop flow manager after wg interface is gone @@ -337,24 +319,64 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - log.Infof("stopped Netbird Engine") + stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer stateCancel() - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - if err := e.stateManager.Stop(ctx); err != nil { - return fmt.Errorf("failed to stop state manager: %w", err) + if err := e.stateManager.Stop(stateCtx); err != nil { + log.Errorf("failed to stop state manager: %v", err) } if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() + e.syncMsgMux.Unlock() + + timeout := e.calculateShutdownTimeout() + log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil { + log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout) + } + + log.Infof("stopped Netbird Engine") return nil } +// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. +func (e *Engine) calculateShutdownTimeout() time.Duration { + peerCount := len(e.peerStore.PeersPubKey()) + + baseTimeout := 10 * time.Second + perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond + timeout := baseTimeout + perPeerTimeout + + maxTimeout := 30 * time.Second + if timeout > maxTimeout { + timeout = maxTimeout + } + + return timeout +} + +// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout. +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service @@ -394,8 +416,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) if err != nil { return fmt.Errorf("create rosenpass manager: %w", err) } - err := e.rpManager.Run() - if err != nil { + if err := e.rpManager.Run(); err != nil { return fmt.Errorf("run rosenpass manager: %w", err) } } @@ -447,6 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) } if err := e.createFirewall(); err != nil { + e.close() return err } @@ -484,14 +506,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // monitor WireGuard interface lifecycle and restart engine on changes e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) + e.shutdownWg.Add(1) go func() { - defer e.wgIfaceMonitorWg.Done() + defer e.shutdownWg.Done() if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() + e.triggerClientRestart() } else if err != nil { log.Warnf("WireGuard interface monitor: %s", err) } @@ -500,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return nil } +func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + e.handleAutoUpdateVersion(autoUpdateSettings, true) +} + func (e *Engine) createFirewall() error { if e.config.DisableFirewall { log.Infof("firewall is disabled") @@ -507,7 +536,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil @@ -671,14 +700,10 @@ func (e *Engine) removeAllPeers() error { return nil } -// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server +// removePeer closes an existing peer connection and removes a peer func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) - if !isNil(e.sshServer) { - e.sshServer.RemoveAuthorizedKey(peerKey) - } - e.connMgr.RemovePeerConn(peerKey) err := e.statusRecorder.RemovePeer(peerKey) @@ -712,10 +737,54 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg return nil } +func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) { + if autoUpdateSettings == nil { + return + } + + disabled := autoUpdateSettings.Version == disableAutoUpdate + + // Stop and cleanup if disabled + if e.updateManager != nil && disabled { + log.Infof("auto-update is disabled, stopping update manager") + e.updateManager.Stop() + e.updateManager = nil + return + } + + // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup + if !autoUpdateSettings.AlwaysUpdate && !initialCheck { + log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check") + return + } + + // Start manager if needed + if e.updateManager == nil { + log.Infof("starting auto-update manager") + updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager) + if err != nil { + return + } + e.updateManager = updateManager + e.updateManager.Start(e.ctx) + } + log.Infof("handling auto-update version: %s", autoUpdateSettings.Version) + e.updateManager.SetVersion(autoUpdateSettings.Version) +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + // Check context INSIDE lock to ensure atomicity with shutdown + if e.ctx.Err() != nil { + return e.ctx.Err() + } + + if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { + e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false) + } + if update.GetNetbirdConfig() != nil { wCfg := update.GetNetbirdConfig() err := e.updateTURNs(wCfg.GetTurns()) @@ -850,6 +919,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) if err := e.mgmClient.SyncMeta(info); err != nil { @@ -859,65 +933,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { return nil } -func isNil(server nbssh.Server) bool { - return server == nil || reflect.ValueOf(server).IsNil() -} - -func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { - if e.config.BlockInbound { - log.Infof("SSH server is disabled because inbound connections are blocked") - return nil - } - - if !e.config.ServerSSHAllowed { - log.Info("SSH server is not enabled") - return nil - } - - if sshConf.GetSshEnabled() { - if runtime.GOOS == "windows" { - log.Warnf("running SSH server on %s is not supported", runtime.GOOS) - return nil - } - // start SSH server if it wasn't running - if isNil(e.sshServer) { - listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) - if nbnetstack.IsEnabled() { - listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) - } - // nil sshServer means it has not yet been started - var err error - e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr) - - if err != nil { - return fmt.Errorf("create ssh server: %w", err) - } - go func() { - // blocking - err = e.sshServer.Start() - if err != nil { - // will throw error when we stop it even if it is a graceful stop - log.Debugf("stopped SSH server with error %v", err) - } - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() - e.sshServer = nil - log.Infof("stopped SSH server") - }() - } else { - log.Debugf("SSH server is already running") - } - } else if !isNil(e.sshServer) { - // Disable SSH server request, so stop it if it was running - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed to stop SSH server %v", err) - } - e.sshServer = nil - } - return nil -} - func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { if e.wgInterface == nil { return errors.New("wireguard interface is not initialized") @@ -930,8 +945,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { } if conf.GetSshConfig() != nil { - err := e.updateSSH(conf.GetSshConfig()) - if err != nil { + if err := e.updateSSH(conf.GetSshConfig()); err != nil { log.Warnf("failed handling SSH server setup: %v", err) } } @@ -950,7 +964,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() info, err := system.GetInfoWithChecks(e.ctx, e.checks) if err != nil { log.Warnf("failed to get system info with checks: %v", err) @@ -967,6 +983,11 @@ func (e *Engine) receiveManagementEvents() { e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) @@ -1060,10 +1081,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + + if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) } + e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1084,7 +1109,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1121,17 +1146,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.statusRecorder.FinishPeerListModifications() - // update SSHServer by adding remote peer SSH keys - if !isNil(e.sshServer) { - for _, config := range networkMap.GetRemotePeers() { - if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { - err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey())) - if err != nil { - log.Warnf("failed adding authorized key to SSH DefaultServer %v", err) - } - } - } + e.updatePeerSSHHostKeys(networkMap.GetRemotePeers()) + + if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil { + log.Warnf("failed to update SSH client config: %v", err) } + + e.updateSSHServerAuth(networkMap.GetSshAuth()) } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store @@ -1208,15 +1229,24 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + //nolint + forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) + if forwarderPort == 0 { + forwarderPort = nbdns.ForwarderClientPort + } + dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0), + ForwarderPort: forwarderPort, } for _, zone := range protoDNSConfig.GetCustomZones() { dnsZone := nbdns.CustomZone{ - Domain: zone.GetDomain(), + Domain: zone.GetDomain(), + SearchDomainDisabled: zone.GetSearchDomainDisabled(), + SkipPTRProcess: zone.GetSkipPTRProcess(), } for _, record := range zone.Records { dnsRecord := nbdns.SimpleRecord{ @@ -1368,12 +1398,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + // Check context INSIDE lock to ensure atomicity with shutdown + if e.ctx.Err() != nil { + return e.ctx.Err() + } + conn, ok := e.peerStore.PeerConn(msg.Key) if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) @@ -1485,13 +1522,6 @@ func (e *Engine) close() { e.statusRecorder.SetWgIface(nil) } - if !isNil(e.sshServer) { - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed stopping the SSH server: %v", err) - } - } - if e.firewall != nil { err := e.firewall.Close(e.stateManager) if err != nil { @@ -1522,6 +1552,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) netMap, err := e.mgmClient.GetNetworkMap(info) @@ -1667,7 +1702,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. -func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1699,8 +1734,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1717,15 +1756,10 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - -// restartEngine restarts the engine by cancelling the client context -func (e *Engine) restartEngine() { +// triggerClientRestart triggers a full client restart by cancelling the client context. +// Note: This does NOT just restart the engine - it cancels the entire client context, +// which causes the connect client's retry loop to create a completely new engine. +func (e *Engine) triggerClientRestart() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1747,7 +1781,9 @@ func (e *Engine) startNetworkMonitor() { } e.networkMonitor = networkmonitor.New() + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() if err := e.networkMonitor.Listen(e.ctx); err != nil { if errors.Is(err, context.Canceled) { log.Infof("network monitor stopped") @@ -1757,8 +1793,8 @@ func (e *Engine) startNetworkMonitor() { return } - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() + log.Infof("Network monitor: detected network change, triggering client restart") + e.triggerClientRestart() }() } @@ -1839,64 +1875,66 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +func (e *Engine) RenewTun(fd int) error { + e.syncMsgMux.Lock() + wgInterface := e.wgInterface + e.syncMsgMux.Unlock() + + if wgInterface == nil { + return fmt.Errorf("wireguard interface not initialized") + } + + return wgInterface.RenewTun(fd) +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } if !enabled { - if e.dnsForwardMgr == nil { - return - } - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() return } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - - default: + if e.dnsForwardMgr == nil { + e.startDNSForwarder(fwdEntries) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil + e.stopDNSForwarder() } - } -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) +func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil + return } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) +} + +func (e *Engine) stopDNSForwarder() { + if e.dnsForwardMgr == nil { + return + } + + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + + e.dnsForwardMgr = nil } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go new file mode 100644 index 000000000..e683d8cee --- /dev/null +++ b/client/internal/engine_ssh.go @@ -0,0 +1,393 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net/netip" + "strings" + + log "github.com/sirupsen/logrus" + + firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" + sshserver "github.com/netbirdio/netbird/client/ssh/server" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" +) + +type sshServer interface { + Start(ctx context.Context, addr netip.AddrPort) error + Stop() error + GetStatus() (bool, []sshserver.SessionInfo) + UpdateSSHAuth(config *sshauth.Config) +} + +func (e *Engine) setupSSHPortRedirection() error { + if e.firewall == nil || e.wgInterface == nil { + return nil + } + + localAddr := e.wgInterface.Address().IP + if !localAddr.IsValid() { + return errors.New("invalid local NetBird address") + } + + if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil { + return fmt.Errorf("add SSH port redirection: %w", err) + } + log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr) + + return nil +} + +func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { + if e.config.BlockInbound { + log.Info("SSH server is disabled because inbound connections are blocked") + return e.stopSSHServer() + } + + if !e.config.ServerSSHAllowed { + log.Info("SSH server is disabled in config") + return e.stopSSHServer() + } + + if !sshConf.GetSshEnabled() { + if e.config.ServerSSHAllowed { + log.Info("SSH server is locally allowed but disabled by management server") + } + return e.stopSSHServer() + } + + if e.sshServer != nil { + log.Debug("SSH server is already running") + return nil + } + + if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth { + log.Info("starting SSH server without JWT authentication (authentication disabled by config)") + return e.startSSHServer(nil) + } + + if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil { + jwtConfig := &sshserver.JWTConfig{ + Issuer: protoJWT.GetIssuer(), + Audience: protoJWT.GetAudience(), + KeysLocation: protoJWT.GetKeysLocation(), + MaxTokenAge: protoJWT.GetMaxTokenAge(), + } + + return e.startSSHServer(jwtConfig) + } + + return errors.New("SSH server requires valid JWT configuration") +} + +// updateSSHClientConfig updates the SSH client configuration with peer information +func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error { + peerInfo := e.extractPeerSSHInfo(remotePeers) + if len(peerInfo) == 0 { + log.Debug("no SSH-enabled peers found, skipping SSH config update") + return nil + } + + configMgr := sshconfig.New() + if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil { + log.Warnf("failed to update SSH client config: %v", err) + return nil // Don't fail engine startup on SSH config issues + } + + log.Debugf("updated SSH client config with %d peers", len(peerInfo)) + + if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{ + SSHConfigDir: configMgr.GetSSHConfigDir(), + SSHConfigFile: configMgr.GetSSHConfigFile(), + }); err != nil { + log.Warnf("failed to update SSH config state: %v", err) + } + + return nil +} + +// extractPeerSSHInfo extracts SSH information from peer configurations +func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo { + var peerInfo []sshconfig.PeerSSHInfo + + for _, peerConfig := range remotePeers { + if peerConfig.GetSshConfig() == nil { + continue + } + + sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey() + if len(sshPubKeyBytes) == 0 { + continue + } + + peerIP := e.extractPeerIP(peerConfig) + hostname := e.extractHostname(peerConfig) + + peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ + Hostname: hostname, + IP: peerIP, + FQDN: peerConfig.GetFqdn(), + }) + } + + return peerInfo +} + +// extractPeerIP extracts IP address from peer's allowed IPs +func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string { + if len(peerConfig.GetAllowedIps()) == 0 { + return "" + } + + if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil { + return prefix.Addr().String() + } + return "" +} + +// extractHostname extracts short hostname from FQDN +func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { + fqdn := peerConfig.GetFqdn() + if fqdn == "" { + return "" + } + + parts := strings.Split(fqdn, ".") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + return "" +} + +// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access +func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) { + for _, peerConfig := range remotePeers { + if peerConfig.GetSshConfig() == nil { + continue + } + + sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey() + if len(sshPubKeyBytes) == 0 { + continue + } + + if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil { + log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err) + } + } + + log.Debugf("updated peer SSH host keys for daemon API access") +} + +// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN +func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { + e.syncMsgMux.Lock() + statusRecorder := e.statusRecorder + e.syncMsgMux.Unlock() + + if statusRecorder == nil { + return nil, false + } + + fullStatus := statusRecorder.GetFullStatus() + for _, peerState := range fullStatus.Peers { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + if len(peerState.SSHHostKey) > 0 { + return peerState.SSHHostKey, true + } + return nil, false + } + } + + return nil, false +} + +// cleanupSSHConfig removes NetBird SSH client configuration on shutdown +func (e *Engine) cleanupSSHConfig() { + configMgr := sshconfig.New() + + if err := configMgr.RemoveSSHClientConfig(); err != nil { + log.Warnf("failed to remove SSH client config: %v", err) + } else { + log.Debugf("SSH client config cleanup completed") + } +} + +// startSSHServer initializes and starts the SSH server with proper configuration. +func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error { + if e.wgInterface == nil { + return errors.New("wg interface not initialized") + } + + serverConfig := &sshserver.Config{ + HostKeyPEM: e.config.SSHKey, + JWT: jwtConfig, + } + server := sshserver.New(serverConfig) + + wgAddr := e.wgInterface.Address() + server.SetNetworkValidation(wgAddr) + + netbirdIP := wgAddr.IP + listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort) + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + server.SetNetstackNet(netstackNet) + } + + e.configureSSHServer(server) + + if err := server.Start(e.ctx, listenAddr); err != nil { + return fmt.Errorf("start SSH server: %w", err) + } + + e.sshServer = server + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort) + log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort) + } + } + + if err := e.setupSSHPortRedirection(); err != nil { + log.Warnf("failed to setup SSH port redirection: %v", err) + } + + return nil +} + +// configureSSHServer applies SSH configuration options to the server. +func (e *Engine) configureSSHServer(server *sshserver.Server) { + if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot { + server.SetAllowRootLogin(true) + log.Info("SSH root login enabled") + } else { + server.SetAllowRootLogin(false) + log.Info("SSH root login disabled (default)") + } + + if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP { + server.SetAllowSFTP(true) + log.Info("SSH SFTP subsystem enabled") + } else { + server.SetAllowSFTP(false) + log.Info("SSH SFTP subsystem disabled (default)") + } + + if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding { + server.SetAllowLocalPortForwarding(true) + log.Info("SSH local port forwarding enabled") + } else { + server.SetAllowLocalPortForwarding(false) + log.Info("SSH local port forwarding disabled (default)") + } + + if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding { + server.SetAllowRemotePortForwarding(true) + log.Info("SSH remote port forwarding enabled") + } else { + server.SetAllowRemotePortForwarding(false) + log.Info("SSH remote port forwarding disabled (default)") + } +} + +func (e *Engine) cleanupSSHPortRedirection() error { + if e.firewall == nil || e.wgInterface == nil { + return nil + } + + localAddr := e.wgInterface.Address().IP + if !localAddr.IsValid() { + return errors.New("invalid local NetBird address") + } + + if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil { + return fmt.Errorf("remove SSH port redirection: %w", err) + } + log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr) + + return nil +} + +func (e *Engine) stopSSHServer() error { + if e.sshServer == nil { + return nil + } + + if err := e.cleanupSSHPortRedirection(); err != nil { + log.Warnf("failed to cleanup SSH port redirection: %v", err) + } + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort) + log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort) + } + } + + log.Info("stopping SSH server") + err := e.sshServer.Stop() + e.sshServer = nil + if err != nil { + return fmt.Errorf("stop: %w", err) + } + return nil +} + +// GetSSHServerStatus returns the SSH server status and active sessions +func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) { + e.syncMsgMux.Lock() + sshServer := e.sshServer + e.syncMsgMux.Unlock() + + if sshServer == nil { + return false, nil + } + + return sshServer.GetStatus() +} + +// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server +func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) { + if sshAuth == nil { + return + } + + if e.sshServer == nil { + return + } + + protoUsers := sshAuth.GetAuthorizedUsers() + authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers)) + for i, hash := range protoUsers { + if len(hash) != 16 { + log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash)) + return + } + authorizedUsers[i] = sshuserhash.UserIDHash(hash) + } + + machineUsers := make(map[string][]uint32) + for osUser, indexes := range sshAuth.GetMachineUsers() { + machineUsers[osUser] = indexes.GetIndexes() + } + + // Update SSH server with new authorization configuration + authConfig := &sshauth.Config{ + UserIDClaim: sshAuth.GetUserIDClaim(), + AuthorizedUsers: authorizedUsers, + MachineUsers: machineUsers, + } + + e.sshServer.UpdateSSHAuth(authConfig) +} diff --git a/client/internal/engine_stdnet.go b/client/internal/engine_stdnet.go index 9e171b0b2..1ebb5779c 100644 --- a/client/internal/engine_stdnet.go +++ b/client/internal/engine_stdnet.go @@ -7,5 +7,5 @@ import ( ) func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(e.config.IFaceBlackList) + return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList) } diff --git a/client/internal/engine_stdnet_android.go b/client/internal/engine_stdnet_android.go index 68a0ae719..de3c80bcf 100644 --- a/client/internal/engine_stdnet_android.go +++ b/client/internal/engine_stdnet_android.go @@ -3,5 +3,5 @@ package internal import "github.com/netbirdio/netbird/client/internal/stdnet" func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) + return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2f1098100..26ea6f8c2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,6 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,11 +24,18 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" @@ -43,13 +49,12 @@ import ( icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/routemanager" - "github.com/netbirdio/netbird/client/ssh" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -105,6 +110,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RenewTun(_ int) error { + return nil +} + func (m *MockWGIface) RemoveEndpointAddress(_ string) error { return nil } @@ -211,11 +220,13 @@ func TestMain(m *testing.M) { } func TestEngine_SSH(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipping TestEngine_SSH") + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return } - key, err := wgtypes.GeneratePrivateKey() + sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) if err != nil { t.Fatal(err) return @@ -237,45 +248,20 @@ func TestEngine_SSH(t *testing.T) { WgPort: 33100, ServerSSHAllowed: true, MTU: iface.DefaultMTU, + SSHKey: sshKey, }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, + nil, ) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } - var sshKeysAdded []string - var sshPeersRemoved []string - - sshCtx, cancel := context.WithCancel(context.Background()) - - engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) { - return &ssh.MockServer{ - Ctx: sshCtx, - StopFunc: func() error { - cancel() - return nil - }, - StartFunc: func() error { - <-ctx.Done() - return ctx.Err() - }, - AddAuthorizedKeyFunc: func(peer, newKey string) error { - sshKeysAdded = append(sshKeysAdded, newKey) - return nil - }, - RemoveAuthorizedKeyFunc: func(peer string) { - sshPeersRemoved = append(sshPeersRemoved, peer) - }, - }, nil - } err = engine.Start(nil, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer func() { err := engine.Stop() @@ -301,9 +287,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) @@ -311,19 +295,24 @@ func TestEngine_SSH(t *testing.T) { networkMap = &mgmtProto.NetworkMap{ Serial: 7, PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24", - SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}}, + SshConfig: &mgmtProto.SSHConfig{ + SshEnabled: true, + JwtConfig: &mgmtProto.JWTConfig{ + Issuer: "test-issuer", + Audience: "test-audience", + KeysLocation: "test-keys", + MaxTokenAge: 3600, + }, + }}, RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH}, RemotePeersIsEmpty: false, } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) - assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ") // now remove peer networkMap = &mgmtProto.NetworkMap{ @@ -333,13 +322,10 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) - assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") // now disable SSH server networkMap = &mgmtProto.NetworkMap{ @@ -351,12 +337,70 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) +} +func TestEngine_SSHUpdateLogic(t *testing.T) { + // Test that SSH server start/stop logic works based on config + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: false, // Start with SSH disabled + }, + syncMsgMux: &sync.Mutex{}, + } + + // Test SSH disabled config + sshConfig := &mgmtProto.SSHConfig{SshEnabled: false} + err := engine.updateSSH(sshConfig) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + + // Test inbound blocked + engine.config.BlockInbound = true + err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + engine.config.BlockInbound = false + + // Test with server SSH not allowed + err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) +} + +func TestEngine_SSHServerConsistency(t *testing.T) { + + t.Run("server set only on successful creation", func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: true, + SSHKey: []byte("test-key"), + }, + syncMsgMux: &sync.Mutex{}, + } + + engine.wgInterface = nil + + err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + + assert.Error(t, err) + assert.Nil(t, engine.sshServer) + }) + + t.Run("cleanup handles nil gracefully", func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: false, + }, + syncMsgMux: &sync.Mutex{}, + } + + err := engine.stopSSHServer() + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + }) } func TestEngine_UpdateNetworkMap(t *testing.T) { @@ -371,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine( - ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - MTU: iface.DefaultMTU, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -604,7 +640,7 @@ func TestEngine_Sync(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -769,9 +805,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -971,10 +1007,10 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -1497,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil + e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil e.ctx = ctx return e, err } @@ -1556,7 +1592,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -1584,13 +1619,19 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 98fe01912..90b06cbd1 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -20,6 +20,7 @@ import ( type wgIfaceBase interface { Create() error CreateOnAndroid(routeRange []string, ip string, domains []string) error + RenewTun(fd int) error IsUserspaceBind() bool Name() string Address() wgaddr.Address diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 000000000..2564a9905 --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 000000000..792d04215 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. +// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. +func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 000000000..f86dd3877 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { + return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d2..000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c3..e0b09be6c 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. +func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 000000000..d2adb9bf4 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb8..db283ec9a 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. +type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4..0768d9219 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil - } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) IsUserspaceBind() bool { + return false +} + +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7..0626c1815 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time } diff --git a/client/internal/login.go b/client/internal/login.go index 257e3c3ac..f528783ef 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) return serverKey, loginResp, err @@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) if err != nil { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index e3b188468..7752c97b0 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -24,6 +24,7 @@ import ( // Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex + shutdownWg sync.WaitGroup logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker @@ -105,8 +106,15 @@ func (m *Manager) resetClient() error { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - go m.receiveACKs(ctx, flowClient) - go m.startSender(ctx) + m.shutdownWg.Add(2) + go func() { + defer m.shutdownWg.Done() + m.receiveACKs(ctx, flowClient) + }() + go func() { + defer m.shutdownWg.Done() + m.startSender(ctx) + }() return nil } @@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { // Close cleans up all resources func (m *Manager) Close() { m.mux.Lock() - defer m.mux.Unlock() - if err := m.disableFlow(); err != nil { log.Warnf("failed to disable flow manager: %v", err) } + m.mux.Unlock() + + m.shutdownWg.Wait() } // GetLogger returns the flow logger diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index f5eb2c739..b3482f54e 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -1,4 +1,4 @@ -//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package networkmonitor @@ -6,21 +6,19 @@ import ( "context" "errors" "fmt" - "syscall" - "unsafe" log "github.com/sirupsen/logrus" - "golang.org/x/net/route" "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + fd, err := prepareFd() if err != nil { return fmt.Errorf("open routing socket: %v", err) } + defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { @@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } }() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - buf := make([]byte, 2048) - n, err := unix.Read(fd, buf) - if err != nil { - if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Warnf("Network monitor: failed to read from routing socket: %v", err) - } - continue - } - if n < unix.SizeofRtMsghdr { - log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) - continue - } - - msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - - switch msg.Type { - // handle route changes - case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) - if err != nil { - log.Debugf("Network monitor: error parsing routing message: %v", err) - continue - } - - if route.Dst.Bits() != 0 { - continue - } - - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - } - switch msg.Type { - case unix.RTM_ADD: - log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - return nil - case unix.RTM_DELETE: - if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - return nil - } - } - } - } - } -} - -func parseRouteMessage(buf []byte) (*systemops.Route, error) { - msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) - if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) - } - - if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) - } - - msg, ok := msgs[0].(*route.RouteMessage) - if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) - } - - return systemops.MsgToRoute(msg) + return routeCheck(ctx, fd, nexthopv4, nexthopv6) } diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go new file mode 100644 index 000000000..c287236e8 --- /dev/null +++ b/client/internal/networkmonitor/check_change_common.go @@ -0,0 +1,92 @@ +//go:build dragonfly || freebsd || netbsd || openbsd || darwin + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func prepareFd() (int, error) { + return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) +} + +func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Warnf("Network monitor: failed to read from routing socket: %v", err) + } + continue + } + if n < unix.SizeofRtMsghdr { + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) + if err != nil { + log.Debugf("Network monitor: error parsing routing message: %v", err) + continue + } + + if route.Dst.Bits() != 0 { + continue + } + + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } + switch msg.Type { + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + return nil + case unix.RTM_DELETE: + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) + return nil + } + } + } + } + } +} + +func parseRouteMessage(buf []byte) (*systemops.Route, error) { + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return systemops.MsgToRoute(msg) +} diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go new file mode 100644 index 000000000..ddc6e1736 --- /dev/null +++ b/client/internal/networkmonitor/check_change_darwin.go @@ -0,0 +1,149 @@ +//go:build darwin && !ios + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// todo: refactor to not use static functions + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + fd, err := prepareFd() + if err != nil { + return fmt.Errorf("open routing socket: %v", err) + } + + defer func() { + if err := unix.Close(fd); err != nil { + if !errors.Is(err, unix.EBADF) { + log.Warnf("Network monitor: failed to close routing socket: %v", err) + } + } + }() + + routeChanged := make(chan struct{}) + go func() { + _ = routeCheck(ctx, fd, nexthopv4, nexthopv6) + close(routeChanged) + }() + + wakeUp := make(chan struct{}) + go func() { + wakeUpListen(ctx) + close(wakeUp) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-routeChanged: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("route change detected") + return nil + case <-wakeUp: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("wakeup detected") + return nil + } +} + +func wakeUpListen(ctx context.Context) { + log.Infof("start to watch for system wakeups") + var ( + initialHash uint32 + err error + ) + + // Keep retrying until initial sysctl succeeds or context is canceled + for { + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + default: + initialHash, err = readSleepTimeHash() + if err != nil { + log.Errorf("failed to detect initial sleep time: %v", err) + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + case <-time.After(3 * time.Second): + continue + } + } + log.Debugf("initial wakeup hash: %d", initialHash) + break + } + break + } + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info("context canceled, stopping wakeUpListen") + return + + case <-ticker.C: + newHash, err := readSleepTimeHash() + if err != nil { + log.Errorf("failed to read sleep time hash: %v", err) + continue + } + + if newHash == initialHash { + log.Tracef("no wakeup detected") + continue + } + + upOut, err := exec.Command("uptime").Output() + if err != nil { + log.Errorf("failed to run uptime command: %v", err) + upOut = []byte("unknown") + } + log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut) + return + } + } +} + +func readSleepTimeHash() (uint32, error) { + cmd := exec.Command("sysctl", "kern.sleeptime") + out, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("failed to run sysctl: %w", err) + } + + h, err := hash(out) + if err != nil { + return 0, fmt.Errorf("failed to compute hash: %w", err) + } + + return h, nil +} + +func hash(data []byte) (uint32, error) { + hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher + if _, err := hasher.Write(data); err != nil { + return 0, err + } + return hasher.Sum32(), nil +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index accdd9c9d..6d019258d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { event := make(chan struct{}, 1) go nw.checkChanges(ctx, event, nexthop4, nexthop6) + log.Infof("start watching for network changes") // debounce changes timer := time.NewTimer(0) timer.Stop() diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 68afe986a..426c31e1a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { } }() - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { + if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { return false } diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go index 39cb95591..52d66159c 100644 --- a/client/internal/peer/endpoint.go +++ b/client/internal/peer/endpoint.go @@ -20,7 +20,7 @@ type EndpointUpdater struct { wgConfig WgConfig initiator bool - // mu protects updateWireGuardPeer and cancelFunc + // mu protects cancelFunc mu sync.Mutex cancelFunc func() updateWg sync.WaitGroup @@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U case <-ctx.Done(): return case <-t.C: - e.mu.Lock() if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) } - e.mu.Unlock() } } diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 32a458d00..7f500c410 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -2,6 +2,7 @@ package peer import ( "os" + "runtime" "strings" ) @@ -10,5 +11,8 @@ const ( ) func isForceRelayed() bool { + if runtime.GOOS == "js" { + return true + } return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") } diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 0f22ee7b0..a201dd095 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { log.Debugf("Gathering ICE candidates") - agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) if err != nil { return false, fmt.Errorf("create ICE agent: %w", err) } diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 686430752..6f4f5ad4f 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -19,11 +19,10 @@ type SRWatcher struct { signalClient chNotifier relayManager chNotifier - listeners map[chan struct{}]struct{} - mu sync.Mutex - iFaceDiscover stdnet.ExternalIFaceDiscover - iceConfig ice.Config - + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config cancelIceMonitor context.CancelFunc } diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index e80c98884..79f68d279 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,6 +1,7 @@ package ice import ( + "context" "sync" "time" @@ -22,6 +23,8 @@ const ( iceFailedTimeoutDefault = 6 * time.Second // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package iceRelayAcceptanceMinWaitDefault = 2 * time.Second + // iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete + iceAgentCloseTimeout = 3 * time.Second ) type ThreadSafeAgent struct { @@ -32,18 +35,28 @@ type ThreadSafeAgent struct { func (a *ThreadSafeAgent) Close() error { var err error a.once.Do(func() { - err = a.Agent.Close() + done := make(chan error, 1) + go func() { + done <- a.Agent.Close() + }() + + select { + case err = <-done: + case <-time.After(iceAgentCloseTimeout): + log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) + err = nil + } }) return err } -func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { +func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() iceFailedTimeout := iceFailedTimeout() iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList) if err != nil { log.Errorf("failed to create pion's stdnet: %s", err) } diff --git a/client/internal/peer/ice/stdnet.go b/client/internal/peer/ice/stdnet.go index 3ce83727e..685ed0363 100644 --- a/client/internal/peer/ice/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -3,9 +3,11 @@ package ice import ( + "context" + "github.com/netbirdio/netbird/client/internal/stdnet" ) -func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNet(ifaceBlacklist) +func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ctx, ifaceBlacklist) } diff --git a/client/internal/peer/ice/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go index 84c665e6f..5033ec1b9 100644 --- a/client/internal/peer/ice/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,7 +1,11 @@ package ice -import "github.com/netbirdio/netbird/client/internal/stdnet" +import ( + "context" -func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 239cce7e0..76f4f523c 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -21,9 +21,9 @@ import ( "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" relayClient "github.com/netbirdio/netbird/shared/relay/client" - "github.com/netbirdio/netbird/route" ) const eventQueueSize = 10 @@ -67,6 +67,7 @@ type State struct { BytesRx int64 Latency time.Duration RosenpassEnabled bool + SSHHostKey []byte routes map[string]struct{} } @@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error { return nil } +// UpdatePeerSSHHostKey updates peer's SSH host key +func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peerPubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.SSHHostKey = sshHostKey + d.peers[peerPubKey] = peerState + + return nil +} + // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 3675f0157..840fc9241 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -209,7 +209,7 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { - agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } @@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { if isController(w.config) { - return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index a713bb342..23c92e8af 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct { DisablePromptLogin bool // LoginFlag is used to configure the PKCE flow login behavior LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..84ee73902 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "os/user" "path/filepath" "reflect" "runtime" @@ -44,24 +45,30 @@ var DefaultInterfaceBlacklist = []string{ // ConfigInput carries configuration changes to the client type ConfigInput struct { - ManagementURL string - AdminURL string - ConfigPath string - StateFilePath string - PreSharedKey *string - ServerSSHAllowed *bool - NATExternalIPs []string - CustomDNSAddress []byte - RosenpassEnabled *bool - RosenpassPermissive *bool - InterfaceName *string - WireguardPort *int - NetworkMonitor *bool - DisableAutoConnect *bool - ExtraIFaceBlackList []string - DNSRouteInterval *time.Duration - ClientCertPath string - ClientCertKeyPath string + ManagementURL string + AdminURL string + ConfigPath string + StateFilePath string + PreSharedKey *string + ServerSSHAllowed *bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool + SSHJWTCacheTTL *int + NATExternalIPs []string + CustomDNSAddress []byte + RosenpassEnabled *bool + RosenpassPermissive *bool + InterfaceName *string + WireguardPort *int + NetworkMonitor *bool + DisableAutoConnect *bool + ExtraIFaceBlackList []string + DNSRouteInterval *time.Duration + ClientCertPath string + ClientCertKeyPath string DisableClientRoutes *bool DisableServerRoutes *bool @@ -82,18 +89,24 @@ type ConfigInput struct { // Config Configuration type type Config struct { // Wireguard private key of local peer - PrivateKey string - PreSharedKey string - ManagementURL *url.URL - AdminURL *url.URL - WgIface string - WgPort int - NetworkMonitor *bool - IFaceBlackList []string - DisableIPv6Discovery bool - RosenpassEnabled bool - RosenpassPermissive bool - ServerSSHAllowed *bool + PrivateKey string + PreSharedKey string + ManagementURL *url.URL + AdminURL *url.URL + WgIface string + WgPort int + NetworkMonitor *bool + IFaceBlackList []string + DisableIPv6Discovery bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed *bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool + SSHJWTCacheTTL *int DisableClientRoutes bool DisableServerRoutes bool @@ -153,19 +166,26 @@ func getConfigDir() (string, error) { if ConfigDirOverride != "" { return ConfigDirOverride, nil } - configDir, err := os.UserConfigDir() + + base, err := baseConfigDir() if err != nil { return "", err } - configDir = filepath.Join(configDir, "netbird") - if _, err := os.Stat(configDir); os.IsNotExist(err) { - if err := os.MkdirAll(configDir, 0755); err != nil { - return "", err + configDir := filepath.Join(base, "netbird") + if err := os.MkdirAll(configDir, 0o755); err != nil { + return "", err + } + return configDir, nil +} + +func baseConfigDir() (string, error) { + if runtime.GOOS == "darwin" { + if u, err := user.Current(); err == nil && u.HomeDir != "" { + return filepath.Join(u.HomeDir, "Library", "Application Support"), nil } } - - return configDir, nil + return os.UserConfigDir() } func getConfigDirForUser(username string) (string, error) { @@ -195,6 +215,7 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + WgPort: iface.DefaultWgPort, } if _, err := config.apply(input); err != nil { @@ -375,6 +396,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot { + if *input.EnableSSHRoot { + log.Infof("enabling SSH root login") + } else { + log.Infof("disabling SSH root login") + } + config.EnableSSHRoot = input.EnableSSHRoot + updated = true + } + + if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP { + if *input.EnableSSHSFTP { + log.Infof("enabling SSH SFTP subsystem") + } else { + log.Infof("disabling SSH SFTP subsystem") + } + config.EnableSSHSFTP = input.EnableSSHSFTP + updated = true + } + + if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding { + if *input.EnableSSHLocalPortForwarding { + log.Infof("enabling SSH local port forwarding") + } else { + log.Infof("disabling SSH local port forwarding") + } + config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding + updated = true + } + + if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding { + if *input.EnableSSHRemotePortForwarding { + log.Infof("enabling SSH remote port forwarding") + } else { + log.Infof("disabling SSH remote port forwarding") + } + config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding + updated = true + } + + if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth { + if *input.DisableSSHAuth { + log.Infof("disabling SSH authentication") + } else { + log.Infof("enabling SSH authentication") + } + config.DisableSSHAuth = input.DisableSSHAuth + updated = true + } + + if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL { + log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL) + config.SSHJWTCacheTTL = input.SSHJWTCacheTTL + updated = true + } + if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { log.Infof("updating DNS route interval to %s (old value %s)", input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 45e37bf0e..ab13cf389 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -5,11 +5,14 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) @@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) { } } +func TestNewProfileDefaults(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + }) + require.NoError(t, err, "should create new config") + + assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default") + assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default") + assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated") + assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated") + assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default") + assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820") + assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default") + assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default") + assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set") + assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set") + assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults") + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS") + assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS") + } +} + +func TestWireguardPortZeroExplicit(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // Create a new profile with explicit port 0 (random port) + explicitZero := 0 + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: &explicitZero, + }) + require.NoError(t, err, "should create config with explicit port 0") + + assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user") + + // Verify it persists + readConfig, err := GetConfig(configPath) + require.NoError(t, err) + assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file") +} + +func TestWireguardPortDefaultVsExplicit(t *testing.T) { + tests := []struct { + name string + wireguardPort *int + expectedPort int + description string + }{ + { + name: "no port specified uses default", + wireguardPort: nil, + expectedPort: iface.DefaultWgPort, + description: "When user doesn't specify port, default to 51820", + }, + { + name: "explicit zero for random port", + wireguardPort: func() *int { v := 0; return &v }(), + expectedPort: 0, + description: "When user explicitly sets 0, use 0 for random port", + }, + { + name: "explicit custom port", + wireguardPort: func() *int { v := 52000; return &v }(), + expectedPort: 52000, + description: "When user sets custom port, use that port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: tt.wireguardPort, + }) + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedPort, config.WgPort, tt.description) + }) + } +} + func TestUpdateOldManagementURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/profilemanager/profilemanager.go b/client/internal/profilemanager/profilemanager.go index fe0afae2b..c87f521cb 100644 --- a/client/internal/profilemanager/profilemanager.go +++ b/client/internal/profilemanager/profilemanager.go @@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error { return nil } + +// GetLoginHint retrieves the email from the active profile to use as login_hint. +func GetLoginHint() string { + pm := NewProfileManager() + activeProf, err := pm.GetActiveProfile() + if err != nil { + log.Debugf("failed to get active profile for login hint: %v", err) + return "" + } + + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + return "" + } + + return profileState.Email +} diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go index faccf5f68..5a0c14000 100644 --- a/client/internal/profilemanager/service.go +++ b/client/internal/profilemanager/service.go @@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) { } type ServiceManager struct { + profilesDir string // If set, overrides ConfigDirOverride for profile operations } func NewServiceManager(defaultConfigPath string) *ServiceManager { @@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager { return &ServiceManager{} } +// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory +// This allows setting the profiles directory without modifying the global ConfigDirOverride +func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager { + if defaultConfigPath != "" { + DefaultConfigPath = defaultConfigPath + } + return &ServiceManager{ + profilesDir: profilesDir, + } +} + func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) { if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil { @@ -240,7 +252,7 @@ func (s *ServiceManager) DefaultProfilePath() string { } func (s *ServiceManager) AddProfile(profileName, username string) error { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return fmt.Errorf("failed to get config directory: %w", err) } @@ -270,7 +282,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error { } func (s *ServiceManager) RemoveProfile(profileName, username string) error { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return fmt.Errorf("failed to get config directory: %w", err) } @@ -302,7 +314,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error { } func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return nil, fmt.Errorf("failed to get config directory: %w", err) } @@ -361,7 +373,7 @@ func (s *ServiceManager) GetStatePath() string { return defaultStatePath } - configDir, err := getConfigDirForUser(activeProf.Username) + configDir, err := s.getConfigDir(activeProf.Username) if err != nil { log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err) return defaultStatePath @@ -369,3 +381,12 @@ func (s *ServiceManager) GetStatePath() string { return filepath.Join(configDir, activeProf.Name+".state.json") } + +// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser +func (s *ServiceManager) getConfigDir(username string) (string, error) { + if s.profilesDir != "" { + return s.profilesDir, nil + } + + return getConfigDirForUser(username) +} diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..59be5b0a7 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,15 +33,171 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe is return fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age < p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) + log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) + } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return @@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } // ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -119,7 +286,7 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return @@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) + for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go index def18411f..8b5407850 100644 --- a/client/internal/routemanager/common/params.go +++ b/client/internal/routemanager/common/params.go @@ -1,6 +1,7 @@ package common import ( + "sync/atomic" "time" "github.com/netbirdio/netbird/client/firewall/manager" @@ -25,4 +26,5 @@ type HandlerParams struct { UseNewDNSRoute bool Firewall manager.Manager FakeIPManager *fakeip.Manager + ForwarderPort *atomic.Uint32 } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..348338dac 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -18,7 +19,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common" @@ -55,6 +55,7 @@ type DnsInterceptor struct { peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager + forwarderPort *atomic.Uint32 } func New(params common.HandlerParams) *DnsInterceptor { @@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor { firewall: params.Firewall, fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), + forwarderPort: params.ForwarderPort, } } @@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 587e05c74..8d1398a7a 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -18,8 +18,8 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const ( diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..2baa0e668 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -10,6 +10,7 @@ import ( "runtime" "slices" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -37,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/client/net" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" "github.com/netbirdio/netbird/version" @@ -54,6 +56,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string SetFirewall(firewall.Manager) error + SetDNSForwarderPort(port uint16) Stop(stateManager *statemanager.Manager) } @@ -78,6 +81,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex + shutdownWg sync.WaitGroup clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector serverRouter *server.Router @@ -101,12 +105,13 @@ type DefaultManager struct { disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler fakeIPManager *fakeip.Manager + dnsForwarderPort atomic.Uint32 } func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) @@ -130,6 +135,7 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } + dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort)) useNoop := netstack.IsEnabled() || config.DisableClientRoutes dm.setupRefCounters(useNoop) @@ -270,9 +276,15 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { return nil } +// SetDNSForwarderPort sets the DNS forwarder port for route handlers +func (m *DefaultManager) SetDNSForwarderPort(port uint16) { + m.dnsForwarderPort.Store(uint32(port)) +} + // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() + m.shutdownWg.Wait() if m.serverRouter != nil { m.serverRouter.CleanUp() } @@ -345,6 +357,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { UseNewDNSRoute: m.useNewDNSRoute, Firewall: m.firewall, FakeIPManager: m.fakeIPManager, + ForwarderPort: &m.dnsForwarderPort, } handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { @@ -474,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { } clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } @@ -516,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout } clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() } update := client.RoutesUpdate{ UpdateSerial: updateSerial, diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index d2f02526c..3697545ae 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -6,7 +6,7 @@ import ( "net/netip" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/stretchr/testify/require" @@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index be633c3fa..6b06144b2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } +// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface +func (m *MockManager) SetDNSForwarderPort(port uint16) { +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..01916fbe3 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -15,7 +15,7 @@ import ( "syscall" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen peerPrivateKey, err := wgtypes.GeneratePrivateKey() require.NoError(t, err) - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) require.NoError(t, err) opts := iface.WGIFaceOpts{ @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index e4a78599e..61c8bbc79 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,8 +9,6 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) @@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { - log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] isSelected := !deselected - log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) return isSelected } diff --git a/client/internal/sleep/detector_darwin.go b/client/internal/sleep/detector_darwin.go new file mode 100644 index 000000000..3d6747ed1 --- /dev/null +++ b/client/internal/sleep/detector_darwin.go @@ -0,0 +1,218 @@ +//go:build darwin && !ios + +package sleep + +/* +#cgo LDFLAGS: -framework IOKit -framework CoreFoundation +#include +#include +#include + +extern void sleepCallbackBridge(); +extern void poweredOnCallbackBridge(); +extern void suspendedCallbackBridge(); +extern void resumedCallbackBridge(); + + +// C global variables for IOKit state +static IONotificationPortRef g_notifyPortRef = NULL; +static io_object_t g_notifierObject = 0; +static io_object_t g_generalInterestNotifier = 0; +static io_connect_t g_rootPort = 0; +static CFRunLoopRef g_runLoop = NULL; + +static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) { + switch (messageType) { + case kIOMessageSystemWillSleep: + sleepCallbackBridge(); + IOAllowPowerChange(g_rootPort, (long)messageArgument); + break; + case kIOMessageSystemHasPoweredOn: + poweredOnCallbackBridge(); + break; + case kIOMessageServiceIsSuspended: + suspendedCallbackBridge(); + break; + case kIOMessageServiceIsResumed: + resumedCallbackBridge(); + break; + default: + break; + } +} + +static void registerNotifications() { + g_rootPort = IORegisterForSystemPower( + NULL, + &g_notifyPortRef, + (IOServiceInterestCallback)sleepCallback, + &g_notifierObject + ); + + if (g_rootPort == 0) { + return; + } + + CFRunLoopAddSource(CFRunLoopGetCurrent(), + IONotificationPortGetRunLoopSource(g_notifyPortRef), + kCFRunLoopCommonModes); + + g_runLoop = CFRunLoopGetCurrent(); + CFRunLoopRun(); +} + +static void unregisterNotifications() { + CFRunLoopRemoveSource(g_runLoop, + IONotificationPortGetRunLoopSource(g_notifyPortRef), + kCFRunLoopCommonModes); + + IODeregisterForSystemPower(&g_notifierObject); + IOServiceClose(g_rootPort); + IONotificationPortDestroy(g_notifyPortRef); + CFRunLoopStop(g_runLoop); + + g_notifyPortRef = NULL; + g_notifierObject = 0; + g_rootPort = 0; + g_runLoop = NULL; +} + +*/ +import "C" + +import ( + "context" + "fmt" + "runtime" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + serviceRegistry = make(map[*Detector]struct{}) + serviceRegistryMu sync.Mutex +) + +//export sleepCallbackBridge +func sleepCallbackBridge() { + log.Info("sleepCallbackBridge event triggered") + + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + for svc := range serviceRegistry { + svc.triggerCallback(EventTypeSleep) + } +} + +//export resumedCallbackBridge +func resumedCallbackBridge() { + log.Info("resumedCallbackBridge event triggered") +} + +//export suspendedCallbackBridge +func suspendedCallbackBridge() { + log.Info("suspendedCallbackBridge event triggered") +} + +//export poweredOnCallbackBridge +func poweredOnCallbackBridge() { + log.Info("poweredOnCallbackBridge event triggered") + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + for svc := range serviceRegistry { + svc.triggerCallback(EventTypeWakeUp) + } +} + +type Detector struct { + callback func(event EventType) + ctx context.Context + cancel context.CancelFunc +} + +func NewDetector() (*Detector, error) { + return &Detector{}, nil +} + +func (d *Detector) Register(callback func(event EventType)) error { + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + if _, exists := serviceRegistry[d]; exists { + return fmt.Errorf("detector service already registered") + } + + d.callback = callback + + d.ctx, d.cancel = context.WithCancel(context.Background()) + + if len(serviceRegistry) > 0 { + serviceRegistry[d] = struct{}{} + return nil + } + + serviceRegistry[d] = struct{}{} + + // CFRunLoop must run on a single fixed OS thread + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + C.registerNotifications() + }() + + log.Info("sleep detection service started on macOS") + return nil +} + +// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down +// and the runloop is stopped and cleaned up. +func (d *Detector) Deregister() error { + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + _, exists := serviceRegistry[d] + if !exists { + return nil + } + + // cancel and remove this detector + d.cancel() + delete(serviceRegistry, d) + + // If other Detectors still exist, leave IOKit running + if len(serviceRegistry) > 0 { + return nil + } + + log.Info("sleep detection service stopping (deregister)") + + // Deregister IOKit notifications, stop runloop, and free resources + C.unregisterNotifications() + + return nil +} + +func (d *Detector) triggerCallback(event EventType) { + doneChan := make(chan struct{}) + + timeout := time.NewTimer(500 * time.Millisecond) + defer timeout.Stop() + + cb := d.callback + go func(callback func(event EventType)) { + log.Info("sleep detection event fired") + callback(event) + close(doneChan) + }(cb) + + select { + case <-doneChan: + case <-d.ctx.Done(): + case <-timeout.C: + log.Warnf("sleep callback timed out") + } +} diff --git a/client/internal/sleep/detector_notsupported.go b/client/internal/sleep/detector_notsupported.go new file mode 100644 index 000000000..6323bf5d1 --- /dev/null +++ b/client/internal/sleep/detector_notsupported.go @@ -0,0 +1,9 @@ +//go:build !darwin || ios + +package sleep + +import "fmt" + +func NewDetector() (detector, error) { + return nil, fmt.Errorf("sleep not supported on this platform") +} diff --git a/client/internal/sleep/service.go b/client/internal/sleep/service.go new file mode 100644 index 000000000..196a33f52 --- /dev/null +++ b/client/internal/sleep/service.go @@ -0,0 +1,37 @@ +package sleep + +var ( + EventTypeUnknown EventType = 0 + EventTypeSleep EventType = 1 + EventTypeWakeUp EventType = 2 +) + +type EventType int + +type detector interface { + Register(callback func(eventType EventType)) error + Deregister() error +} + +type Service struct { + detector detector +} + +func New() (*Service, error) { + d, err := NewDetector() + if err != nil { + return nil, err + } + + return &Service{ + detector: d, + }, nil +} + +func (s *Service) Register(callback func(eventType EventType)) error { + return s.detector.Register(callback) +} + +func (s *Service) Deregister() error { + return s.detector.Deregister() +} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 4b031c05c..381886ac6 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -4,17 +4,28 @@ package stdnet import ( + "context" + "errors" "fmt" + "net" + "net/netip" "slices" + "strconv" "sync" "time" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + + "github.com/netbirdio/netbird/client/iface/netstack" ) -const updateInterval = 30 * time.Second +const ( + updateInterval = 30 * time.Second + dnsResolveTimeout = 30 * time.Second +) + +var errNoSuitableAddress = errors.New("no suitable address found") // Net is an implementation of the net.Net interface // based on functions of the standard net package. @@ -28,12 +39,19 @@ type Net struct { // mu is shared between interfaces and lastUpdate mu sync.Mutex + + // ctx is the context for network operations that supports cancellation + ctx context.Context } // NewNetWithDiscover creates a new StdNet instance. -func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { +func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client // so in android cli use pionDiscover @@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri } // NewNet creates a new StdNet instance. -func NewNet(disallowList []string) (*Net, error) { +func NewNet(ctx context.Context, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ iFaceDiscover: pionDiscover{}, interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } return n, n.UpdateInterfaces() } +// resolveAddr performs DNS resolution with context support and timeout. +func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return netip.AddrPort{}, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err) + } + if port < 0 || port > 65535 { + return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port) + } + + ipNet := "ip" + switch network { + case "tcp4", "udp4": + ipNet = "ip4" + case "tcp6", "udp6": + ipNet = "ip6" + } + + if host == "" { + addr := netip.IPv4Unspecified() + if ipNet == "ip6" { + addr = netip.IPv6Unspecified() + } + return netip.AddrPortFrom(addr, uint16(port)), nil + } + + ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout) + defer cancel() + + addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host) + if err != nil { + return netip.AddrPort{}, err + } + + if len(addrs) == 0 { + return netip.AddrPort{}, errNoSuitableAddress + } + + return netip.AddrPortFrom(addrs[0], uint16(port)), nil +} + // UpdateInterfaces updates the internal list of network interfaces // and associated addresses filtering them by name. // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one @@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I } return result } + +// ResolveUDPAddr resolves UDP addresses with context support and timeout. +func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + switch network { + case "udp", "udp4", "udp6": + case "": + network = "udp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err} + } + + return net.UDPAddrFromAddrPort(addrPort), nil +} + +// ResolveTCPAddr resolves TCP addresses with context support and timeout. +func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + switch network { + case "tcp", "tcp4", "tcp6": + case "": + network = "tcp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err} + } + + return net.TCPAddrFromAddrPort(addrPort), nil +} diff --git a/client/internal/templates/pkce-auth-msg.html b/client/internal/templates/pkce-auth-msg.html index 4825c48e7..175a6f05c 100644 --- a/client/internal/templates/pkce-auth-msg.html +++ b/client/internal/templates/pkce-auth-msg.html @@ -1,88 +1,93 @@ + - + + + + NetBird Login + + + - NetBird Login Successful + -
- -
- {{ if .Error }} - - - - -
-
- Login failed +
+
+ + +
+ + + + + + + + + + + + + + + + + + +
- {{ .Error }}. -
- {{ else }} - - - - -
-
- Login successful + +
+ +
+ + {{ if .Error }} + +
+ + + + +
+ {{ else }} + +
+ + + + +
+ {{ end }} + + +
+ {{ if .Error }} +

Login Failed

+ {{ else }} +

Login Successful

+ {{ end }} +
+ + + {{ if .Error }} +
+ {{ .Error }} +
+ {{ else }} + +
+ Your device is now registered and logged in to NetBird. You can now close this window. +
+ {{ end }} + +
- Your device is now registered and logged in to NetBird. -
- You can now close this window.
- {{ end }}
+ diff --git a/client/internal/templates/pkce_auth_msg_test.go b/client/internal/templates/pkce_auth_msg_test.go new file mode 100644 index 000000000..75b1c9e76 --- /dev/null +++ b/client/internal/templates/pkce_auth_msg_test.go @@ -0,0 +1,299 @@ +package templates + +import ( + "html/template" + "os" + "path/filepath" + "testing" +) + +func TestPKCEAuthMsgTemplate(t *testing.T) { + tests := []struct { + name string + data map[string]string + outputFile string + expectedTitle string + expectedInContent []string + notExpectedInContent []string + }{ + { + name: "error_state", + data: map[string]string{ + "Error": "authentication failed: invalid state", + }, + outputFile: "pkce-auth-error.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication failed: invalid state", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + { + name: "success_state", + data: map[string]string{ + // No error field means success + }, + outputFile: "pkce-auth-success.html", + expectedTitle: "Login Successful", + expectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird. You can now close this window.", + }, + notExpectedInContent: []string{ + "Login Failed", + }, + }, + { + name: "error_state_timeout", + data: map[string]string{ + "Error": "authentication timeout: request expired after 5 minutes", + }, + outputFile: "pkce-auth-timeout.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication timeout: request expired after 5 minutes", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the template + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + // Create temp directory for this test + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, tt.outputFile) + + // Create output file + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + + // Execute the template + if err := tmpl.Execute(file, tt.data); err != nil { + file.Close() + t.Fatalf("Failed to execute template: %v", err) + } + file.Close() + + t.Logf("Generated test output: %s", outputPath) + + // Read the generated file + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + contentStr := string(content) + + // Verify file has content + if len(contentStr) == 0 { + t.Error("Output file is empty") + } + + // Verify basic HTML structure + basicElements := []string{ + "", + "", + "", + "NetBird", + } + + for _, elem := range basicElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + + // Verify expected title + if !contains(contentStr, tt.expectedTitle) { + t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle) + } + + // Verify expected content is present + for _, expected := range tt.expectedInContent { + if !contains(contentStr, expected) { + t.Errorf("Expected HTML to contain '%s', but it was not found", expected) + } + } + + // Verify unexpected content is not present + for _, notExpected := range tt.notExpectedInContent { + if contains(contentStr, notExpected) { + t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected) + } + } + }) + } +} + +func TestPKCEAuthMsgTemplateValidation(t *testing.T) { + // Test that the template can be parsed without errors + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + // Test with empty data + t.Run("empty_data", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "empty-data.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + if err := tmpl.Execute(file, nil); err != nil { + t.Errorf("Template execution with nil data failed: %v", err) + } + }) + + // Test with error data + t.Run("with_error", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "with-error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{ + "Error": "test error message", + } + if err := tmpl.Execute(file, data); err != nil { + t.Errorf("Template execution with error data failed: %v", err) + } + }) +} + +func TestPKCEAuthMsgTemplateContent(t *testing.T) { + // Test that the template contains expected elements + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + t.Run("success_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "success.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{} + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for success indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Successful", + "NetBird", + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) + + t.Run("error_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + errorMsg := "test error message" + data := map[string]string{ + "Error": errorMsg, + } + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for error indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Failed", + errorMsg, + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && containsHelper(s, substr))) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/client/internal/updatemanager/doc.go b/client/internal/updatemanager/doc.go new file mode 100644 index 000000000..54d1bdeab --- /dev/null +++ b/client/internal/updatemanager/doc.go @@ -0,0 +1,35 @@ +// Package updatemanager provides automatic update management for the NetBird client. +// It monitors for new versions, handles update triggers from management server directives, +// and orchestrates the download and installation of client updates. +// +// # Overview +// +// The update manager operates as a background service that continuously monitors for +// available updates and automatically initiates the update process when conditions are met. +// It integrates with the installer package to perform the actual installation. +// +// # Update Flow +// +// The complete update process follows these steps: +// +// 1. Manager receives update directive via SetVersion() or detects new version +// 2. Manager validates update should proceed (version comparison, rate limiting) +// 3. Manager publishes "updating" event to status recorder +// 4. Manager persists UpdateState to track update attempt +// 5. Manager downloads installer file (.msi or .exe) to temporary directory +// 6. Manager triggers installation via installer.RunInstallation() +// 7. Installer package handles the actual installation process +// 8. On next startup, CheckUpdateSuccess() verifies update completion +// 9. Manager publishes success/failure event to status recorder +// 10. Manager cleans up UpdateState +// +// # State Management +// +// Update state is persisted across restarts to track update attempts: +// +// - PreUpdateVersion: Version before update attempt +// - TargetVersion: Version attempting to update to +// +// This enables verification of successful updates and appropriate user notification +// after the client restarts with the new version. +package updatemanager diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updatemanager/downloader/downloader.go new file mode 100644 index 000000000..2ac36efed --- /dev/null +++ b/client/internal/updatemanager/downloader/downloader.go @@ -0,0 +1,138 @@ +package downloader + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" +) + +const ( + userAgent = "NetBird agent installer/%s" + DefaultRetryDelay = 3 * time.Second +) + +func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error { + log.Debugf("starting download from %s", url) + + out, err := os.Create(dstFile) + if err != nil { + return fmt.Errorf("failed to create destination file %q: %w", dstFile, err) + } + defer func() { + if cerr := out.Close(); cerr != nil { + log.Warnf("error closing file %q: %v", dstFile, cerr) + } + }() + + // First attempt + err = downloadToFileOnce(ctx, url, out) + if err == nil { + log.Infof("successfully downloaded file to %s", dstFile) + return nil + } + + // If retryDelay is 0, don't retry + if retryDelay == 0 { + return err + } + + log.Warnf("download failed, retrying after %v: %v", retryDelay, err) + + // Sleep before retry + if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil { + return fmt.Errorf("download cancelled during retry delay: %w", sleepErr) + } + + // Truncate file before retry + if err := out.Truncate(0); err != nil { + return fmt.Errorf("failed to truncate file on retry: %w", err) + } + if _, err := out.Seek(0, 0); err != nil { + return fmt.Errorf("failed to seek to beginning of file: %w", err) + } + + // Second attempt + if err := downloadToFileOnce(ctx, url, out); err != nil { + return fmt.Errorf("download failed after retry: %w", err) + } + + log.Infof("successfully downloaded file to %s", dstFile) + return nil +} + +func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Add User-Agent header + req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion())) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform HTTP request: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + log.Warnf("error closing response body: %v", cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, limit)) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + return data, nil +} + +func downloadToFileOnce(ctx context.Context, url string, out *os.File) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Add User-Agent header + req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion())) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to perform HTTP request: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + log.Warnf("error closing response body: %v", cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode) + } + + if _, err := io.Copy(out, resp.Body); err != nil { + return fmt.Errorf("failed to write response body to file: %w", err) + } + + return nil +} + +func sleepWithContext(ctx context.Context, duration time.Duration) error { + select { + case <-time.After(duration): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updatemanager/downloader/downloader_test.go new file mode 100644 index 000000000..045db3a2d --- /dev/null +++ b/client/internal/updatemanager/downloader/downloader_test.go @@ -0,0 +1,199 @@ +package downloader + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" +) + +const ( + retryDelay = 100 * time.Millisecond +) + +func TestDownloadToFile_Success(t *testing.T) { + // Create a test server that responds successfully + content := "test file content" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file + err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify the file content + data, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } +} + +func TestDownloadToFile_SuccessAfterRetry(t *testing.T) { + content := "test file content after retry" + var attemptCount atomic.Int32 + + // Create a test server that fails on first attempt, succeeds on second + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := attemptCount.Add(1) + if attempt == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file (should succeed after retry) + if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil { + t.Fatalf("expected no error after retry, got: %v", err) + } + + // Verify the file content + data, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } + + // Verify it took 2 attempts + if attemptCount.Load() != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_FailsAfterRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file (should fail after retry) + if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil { + t.Fatal("expected error after retry, got nil") + } + + // Verify it tried 2 times + if attemptCount.Load() != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Create a context that will be cancelled during retry delay + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel after a short delay (during the retry sleep) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + // Download the file (should fail due to context cancellation during retry) + err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile) + if err == nil { + t.Fatal("expected error due to context cancellation, got nil") + } + + // Should have only made 1 attempt (cancelled during retry delay) + if attemptCount.Load() != 1 { + t.Errorf("expected 1 attempt, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_InvalidURL(t *testing.T) { + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile) + if err == nil { + t.Fatal("expected error for invalid URL, got nil") + } +} + +func TestDownloadToFile_InvalidDestination(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("test")) + })) + defer server.Close() + + // Use an invalid destination path + err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt") + if err == nil { + t.Fatal("expected error for invalid destination, got nil") + } +} + +func TestDownloadToFile_NoRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file with retryDelay = 0 (should not retry) + if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil { + t.Fatal("expected error, got nil") + } + + // Verify it only made 1 attempt (no retry) + if attemptCount.Load() != 1 { + t.Errorf("expected 1 attempt, got %d", attemptCount.Load()) + } +} diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updatemanager/installer/binary_nowindows.go new file mode 100644 index 000000000..19f3bef83 --- /dev/null +++ b/client/internal/updatemanager/installer/binary_nowindows.go @@ -0,0 +1,7 @@ +//go:build !windows + +package installer + +func UpdaterBinaryNameWithoutExtension() string { + return updaterBinary +} diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updatemanager/installer/binary_windows.go new file mode 100644 index 000000000..4c66391c2 --- /dev/null +++ b/client/internal/updatemanager/installer/binary_windows.go @@ -0,0 +1,11 @@ +package installer + +import ( + "path/filepath" + "strings" +) + +func UpdaterBinaryNameWithoutExtension() string { + ext := filepath.Ext(updaterBinary) + return strings.TrimSuffix(updaterBinary, ext) +} diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updatemanager/installer/doc.go new file mode 100644 index 000000000..0a60454bb --- /dev/null +++ b/client/internal/updatemanager/installer/doc.go @@ -0,0 +1,111 @@ +// Package installer provides functionality for managing NetBird application +// updates and installations across Windows, macOS. It handles +// the complete update lifecycle including artifact download, cryptographic verification, +// installation execution, process management, and result reporting. +// +// # Architecture +// +// The installer package uses a two-process architecture to enable self-updates: +// +// 1. Service Process: The main NetBird daemon process that initiates updates +// 2. Updater Process: A detached child process that performs the actual installation +// +// This separation is critical because: +// - The service binary cannot update itself while running +// - The installer (EXE/MSI/PKG) will terminate the service during installation +// - The updater process survives service termination and restarts it after installation +// - Results can be communicated back to the service after it restarts +// +// # Update Flow +// +// Service Process (RunInstallation): +// +// 1. Validates target version format (semver) +// 2. Determines installer type (EXE, MSI, PKG, or Homebrew) +// 3. Downloads installer file from GitHub releases (if applicable) +// 4. Verifies installer signature using reposign package (cryptographic verification in service process before +// launching updater) +// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows) +// 6. Launches updater process with detached mode: +// - --temp-dir: Temporary directory path +// - --service-dir: Service installation directory +// - --installer-file: Path to downloaded installer (if applicable) +// - --dry-run: Optional flag to test without actually installing +// 7. Service process continues running (will be terminated by installer later) +// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion +// +// Updater Process (Setup): +// +// 1. Receives parameters from service via command-line arguments +// 2. Runs installer with appropriate silent/quiet flags: +// - Windows EXE: installer.exe /S +// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log +// - macOS PKG: installer -pkg installer.pkg -target / +// - macOS Homebrew: brew upgrade netbirdio/tap/netbird +// 3. Installer terminates daemon and UI processes +// 4. Installer replaces binaries with new version +// 5. Updater waits for installer to complete +// 6. Updater restarts daemon: +// - Windows: netbird.exe service start +// - macOS/Linux: netbird service start +// 7. Updater restarts UI: +// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser +// - macOS: Uses launchctl asuser to launch NetBird.app for console user +// - Linux: Not implemented (UI typically auto-starts) +// 8. Updater writes result.json with success/error status +// 9. Updater process exits +// +// # Result Communication +// +// The ResultHandler (result.go) manages communication between updater and service: +// +// Result Structure: +// +// type Result struct { +// Success bool // true if installation succeeded +// Error string // error message if Success is false +// ExecutedAt time.Time // when installation completed +// } +// +// Result files are automatically cleaned up after being read. +// +// # File Locations +// +// Temporary Directory (platform-specific): +// +// Windows: +// - Path: %ProgramData%\Netbird\tmp-install +// - Example: C:\ProgramData\Netbird\tmp-install +// +// macOS: +// - Path: /var/lib/netbird/tmp-install +// - Requires root permissions +// +// Files created during installation: +// +// tmp-install/ +// installer.log +// updater[.exe] # Copy of service binary +// netbird_installer_*.[exe|msi|pkg] # Downloaded installer +// result.json # Installation result +// msi.log # MSI verbose log (Windows MSI only) +// +// # API Reference +// +// # Cleanup +// +// CleanUpInstallerFiles() removes temporary files after successful installation: +// - Downloaded installer files (*.exe, *.msi, *.pkg) +// - Updater binary copy +// - Does NOT remove result.json (cleaned by ResultHandler after read) +// - Does NOT remove msi.log (kept for debugging) +// +// # Dry-Run Mode +// +// Dry-run mode allows testing the update process without actually installing: +// +// Enable via environment variable: +// +// export NB_AUTO_UPDATE_DRY_RUN=true +// netbird service install-update 0.29.0 +package installer diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updatemanager/installer/installer.go new file mode 100644 index 000000000..caf5873f8 --- /dev/null +++ b/client/internal/updatemanager/installer/installer.go @@ -0,0 +1,50 @@ +//go:build !windows && !darwin + +package installer + +import ( + "context" + "fmt" +) + +const ( + updaterBinary = "updater" +) + +type Installer struct { + tempDir string +} + +// New used by the service +func New() *Installer { + return &Installer{} +} + +// NewWithDir used by the updater process, get the tempDir from the service via cmd line +func NewWithDir(tempDir string) *Installer { + return &Installer{ + tempDir: tempDir, + } +} + +func (u *Installer) TempDir() string { + return "" +} + +func (c *Installer) LogFiles() []string { + return []string{} +} + +func (u *Installer) CleanUpInstallerFiles() error { + return nil +} + +func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error { + return fmt.Errorf("unsupported platform") +} + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) { + return fmt.Errorf("unsupported platform") +} diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updatemanager/installer/installer_common.go new file mode 100644 index 000000000..03378d55f --- /dev/null +++ b/client/internal/updatemanager/installer/installer_common.go @@ -0,0 +1,293 @@ +//go:build windows || darwin + +package installer + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + + "github.com/hashicorp/go-multierror" + goversion "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +type Installer struct { + tempDir string +} + +// New used by the service +func New() *Installer { + return &Installer{ + tempDir: defaultTempDir, + } +} + +// NewWithDir used by the updater process, get the tempDir from the service via cmd line +func NewWithDir(tempDir string) *Installer { + return &Installer{ + tempDir: tempDir, + } +} + +// RunInstallation starts the updater process to run the installation +// This will run by the original service process +func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) { + resultHandler := NewResultHandler(u.tempDir) + + defer func() { + if err != nil { + if writeErr := resultHandler.WriteErr(err); writeErr != nil { + log.Errorf("failed to write error result: %v", writeErr) + } + } + }() + + if err := validateTargetVersion(targetVersion); err != nil { + return err + } + + if err := u.mkTempDir(); err != nil { + return err + } + + var installerFile string + // Download files only when not using any third-party store + if installerType := TypeOfInstaller(ctx); installerType.Downloadable() { + log.Infof("download installer") + var err error + installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion) + if err != nil { + log.Errorf("failed to download installer: %v", err) + return err + } + + artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL) + if err != nil { + log.Errorf("failed to create artifact verify: %v", err) + return err + } + + if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil { + log.Errorf("artifact verification error: %v", err) + return err + } + } + + log.Infof("running installer") + updaterPath, err := u.copyUpdater() + if err != nil { + return err + } + + // the directory where the service has been installed + workspace, err := getServiceDir() + if err != nil { + return err + } + + args := []string{ + "--temp-dir", u.tempDir, + "--service-dir", workspace, + } + + if isDryRunEnabled() { + args = append(args, "--dry-run=true") + } + + if installerFile != "" { + args = append(args, "--installer-file", installerFile) + } + + updateCmd := exec.Command(updaterPath, args...) + log.Infof("starting updater process: %s", updateCmd.String()) + + // Configure the updater to run in a separate session/process group + // so it survives the parent daemon being stopped + setUpdaterProcAttr(updateCmd) + + // Start the updater process asynchronously + if err := updateCmd.Start(); err != nil { + return err + } + + pid := updateCmd.Process.Pid + log.Infof("updater started with PID %d", pid) + + // Release the process so the OS can fully detach it + if err := updateCmd.Process.Release(); err != nil { + log.Warnf("failed to release updater process: %v", err) + } + + return nil +} + +// CleanUpInstallerFiles +// - the installer file (pkg, exe, msi) +// - the selfcopy updater.exe +func (u *Installer) CleanUpInstallerFiles() error { + // Check if tempDir exists + info, err := os.Stat(u.tempDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + if !info.IsDir() { + return nil + } + + var merr *multierror.Error + + if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) { + merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err)) + } + + entries, err := os.ReadDir(u.tempDir) + if err != nil { + return err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + for _, ext := range binaryExtensions { + if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) { + if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil { + merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err)) + } + break + } + } + } + + return merr.ErrorOrNil() +} + +func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) { + fileURL := urlWithVersionArch(installerType, targetVersion) + + // Clean up temp directory on error + var success bool + defer func() { + if !success { + if err := os.RemoveAll(u.tempDir); err != nil { + log.Errorf("error cleaning up temporary directory: %v", err) + } + } + }() + + fileName := path.Base(fileURL) + if fileName == "." || fileName == "/" || fileName == "" { + return "", fmt.Errorf("invalid file URL: %s", fileURL) + } + + outputFilePath := filepath.Join(u.tempDir, fileName) + if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil { + return "", err + } + + success = true + return outputFilePath, nil +} + +func (u *Installer) TempDir() string { + return u.tempDir +} + +func (u *Installer) mkTempDir() error { + if err := os.MkdirAll(u.tempDir, 0o755); err != nil { + log.Debugf("failed to create tempdir: %s", u.tempDir) + return err + } + return nil +} + +func (u *Installer) copyUpdater() (string, error) { + src, err := getServiceBinary() + if err != nil { + return "", fmt.Errorf("failed to get updater binary: %w", err) + } + + dst := filepath.Join(u.tempDir, updaterBinary) + if err := copyFile(src, dst); err != nil { + return "", fmt.Errorf("failed to copy updater binary: %w", err) + } + + if err := os.Chmod(dst, 0o755); err != nil { + return "", fmt.Errorf("failed to set permissions: %w", err) + } + + return dst, nil +} + +func validateTargetVersion(targetVersion string) error { + if targetVersion == "" { + return fmt.Errorf("target version cannot be empty") + } + + _, err := goversion.NewVersion(targetVersion) + if err != nil { + return fmt.Errorf("invalid target version %q: %w", targetVersion, err) + } + + return nil +} + +func copyFile(src, dst string) error { + log.Infof("copying %s to %s", src, dst) + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + defer func() { + if err := in.Close(); err != nil { + log.Warnf("failed to close source file: %v", err) + } + }() + + out, err := os.Create(dst) + if err != nil { + return fmt.Errorf("create destination: %w", err) + } + defer func() { + if err := out.Close(); err != nil { + log.Warnf("failed to close destination file: %v", err) + } + }() + + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("copy: %w", err) + } + + return nil +} + +func getServiceDir() (string, error) { + exePath, err := os.Executable() + if err != nil { + return "", err + } + return filepath.Dir(exePath), nil +} + +func getServiceBinary() (string, error) { + return os.Executable() +} + +func isDryRunEnabled() bool { + return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true") +} diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updatemanager/installer/installer_log_darwin.go new file mode 100644 index 000000000..50dd5d197 --- /dev/null +++ b/client/internal/updatemanager/installer/installer_log_darwin.go @@ -0,0 +1,11 @@ +package installer + +import ( + "path/filepath" +) + +func (u *Installer) LogFiles() []string { + return []string{ + filepath.Join(u.tempDir, LogFile), + } +} diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updatemanager/installer/installer_log_windows.go new file mode 100644 index 000000000..96e4cfd1f --- /dev/null +++ b/client/internal/updatemanager/installer/installer_log_windows.go @@ -0,0 +1,12 @@ +package installer + +import ( + "path/filepath" +) + +func (u *Installer) LogFiles() []string { + return []string{ + filepath.Join(u.tempDir, msiLogFile), + filepath.Join(u.tempDir, LogFile), + } +} diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updatemanager/installer/installer_run_darwin.go new file mode 100644 index 000000000..462e2c227 --- /dev/null +++ b/client/internal/updatemanager/installer/installer_run_darwin.go @@ -0,0 +1,238 @@ +package installer + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + daemonName = "netbird" + updaterBinary = "updater" + uiBinary = "/Applications/NetBird.app" + + defaultTempDir = "/var/lib/netbird/tmp-install" + + pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg" +) + +var ( + binaryExtensions = []string{"pkg"} +) + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) { + resultHandler := NewResultHandler(u.tempDir) + + // Always ensure daemon and UI are restarted after setup + defer func() { + log.Infof("write out result") + var err error + if resultErr == nil { + err = resultHandler.WriteSuccess() + } else { + err = resultHandler.WriteErr(resultErr) + } + if err != nil { + log.Errorf("failed to write update result: %v", err) + } + + // skip service restart if dry-run mode is enabled + if dryRun { + return + } + + log.Infof("starting daemon back") + if err := u.startDaemon(daemonFolder); err != nil { + log.Errorf("failed to start daemon: %v", err) + } + + log.Infof("starting UI back") + if err := u.startUIAsUser(); err != nil { + log.Errorf("failed to start UI: %v", err) + } + + }() + + if dryRun { + time.Sleep(7 * time.Second) + log.Infof("dry-run mode enabled, skipping actual installation") + resultErr = fmt.Errorf("dry-run mode enabled") + return + } + + switch TypeOfInstaller(ctx) { + case TypePKG: + resultErr = u.installPkgFile(ctx, installerFile) + case TypeHomebrew: + resultErr = u.updateHomeBrew(ctx) + } + + return resultErr +} + +func (u *Installer) startDaemon(daemonFolder string) error { + log.Infof("starting netbird service") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start") + if output, err := cmd.CombinedOutput(); err != nil { + log.Warnf("failed to start netbird service: %v, output: %s", err, string(output)) + return err + } + log.Infof("netbird service started successfully") + return nil +} + +func (u *Installer) startUIAsUser() error { + log.Infof("starting netbird-ui: %s", uiBinary) + + // Get the current console user + cmd := exec.Command("stat", "-f", "%Su", "/dev/console") + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("failed to get console user: %w", err) + } + + username := strings.TrimSpace(string(output)) + if username == "" || username == "root" { + return fmt.Errorf("no active user session found") + } + + log.Infof("starting UI for user: %s", username) + + // Get user's UID + userInfo, err := user.Lookup(username) + if err != nil { + return fmt.Errorf("failed to lookup user %s: %w", username, err) + } + + // Start the UI process as the console user using launchctl + // This ensures the app runs in the user's context with proper GUI access + launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary) + log.Infof("launchCmd: %s", launchCmd.String()) + // Set the user's home directory for proper macOS app behavior + launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir) + log.Infof("set HOME environment variable: %s", userInfo.HomeDir) + + if err := launchCmd.Start(); err != nil { + return fmt.Errorf("failed to start UI process: %w", err) + } + + // Release the process so it can run independently + if err := launchCmd.Process.Release(); err != nil { + log.Warnf("failed to release UI process: %v", err) + } + + log.Infof("netbird-ui started successfully for user %s", username) + return nil +} + +func (u *Installer) installPkgFile(ctx context.Context, path string) error { + log.Infof("installing pkg file: %s", path) + + // Kill any existing UI processes before installation + // This ensures the postinstall script's "open $APP" will start the new version + u.killUI() + + volume := "/" + + cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume) + if err := cmd.Start(); err != nil { + return fmt.Errorf("error running pkg file: %w", err) + } + log.Infof("installer started with PID %d", cmd.Process.Pid) + if err := cmd.Wait(); err != nil { + return fmt.Errorf("error running pkg file: %w", err) + } + log.Infof("pkg file installed successfully") + return nil +} + +func (u *Installer) updateHomeBrew(ctx context.Context) error { + log.Infof("updating homebrew") + + // Kill any existing UI processes before upgrade + // This ensures the new version will be started after upgrade + u.killUI() + + // Homebrew must be run as a non-root user + // To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory + // Check both Apple Silicon and Intel Mac paths + brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/" + brewBinPath := "/opt/homebrew/bin/brew" + if _, err := os.Stat(brewTapPath); os.IsNotExist(err) { + // Try Intel Mac path + brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/" + brewBinPath = "/usr/local/bin/brew" + } + + fileInfo, err := os.Stat(brewTapPath) + if err != nil { + return fmt.Errorf("error getting homebrew installation path info: %w", err) + } + + fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t) + if !ok { + return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys()) + } + + // Get username from UID + brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid)) + if err != nil { + return fmt.Errorf("error looking up brew installer user: %w", err) + } + userName := brewUser.Username + // Get user HOME, required for brew to run correctly + // https://github.com/Homebrew/brew/issues/15833 + homeDir := brewUser.HomeDir + + // Check if netbird-ui is installed (must run as the brew user, not root) + checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui") + checkUICmd.Env = append(os.Environ(), "HOME="+homeDir) + uiInstalled := checkUICmd.Run() == nil + + // Homebrew does not support installing specific versions + // Thus it will always update to latest and ignore targetVersion + upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"} + if uiInstalled { + upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui") + } + + cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...) + cmd.Env = append(os.Environ(), "HOME="+homeDir) + + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output)) + } + + log.Infof("homebrew updated successfully") + return nil +} + +func (u *Installer) killUI() { + log.Infof("killing existing netbird-ui processes") + cmd := exec.Command("pkill", "-x", "netbird-ui") + if output, err := cmd.CombinedOutput(); err != nil { + // pkill returns exit code 1 if no processes matched, which is fine + log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output)) + } else { + log.Infof("netbird-ui processes killed") + } +} + +func urlWithVersionArch(_ Type, version string) string { + url := strings.ReplaceAll(pkgDownloadURL, "%version", version) + return strings.ReplaceAll(url, "%arch", runtime.GOARCH) +} diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updatemanager/installer/installer_run_windows.go new file mode 100644 index 000000000..353cd885d --- /dev/null +++ b/client/internal/updatemanager/installer/installer_run_windows.go @@ -0,0 +1,213 @@ +package installer + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + daemonName = "netbird.exe" + uiName = "netbird-ui.exe" + updaterBinary = "updater.exe" + + msiLogFile = "msi.log" + + msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi" + exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe" +) + +var ( + defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install") + + // for the cleanup + binaryExtensions = []string{"msi", "exe"} +) + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) { + resultHandler := NewResultHandler(u.tempDir) + + // Always ensure daemon and UI are restarted after setup + defer func() { + log.Infof("starting daemon back") + if err := u.startDaemon(daemonFolder); err != nil { + log.Errorf("failed to start daemon: %v", err) + } + + log.Infof("starting UI back") + if err := u.startUIAsUser(daemonFolder); err != nil { + log.Errorf("failed to start UI: %v", err) + } + + log.Infof("write out result") + var err error + if resultErr == nil { + err = resultHandler.WriteSuccess() + } else { + err = resultHandler.WriteErr(resultErr) + } + if err != nil { + log.Errorf("failed to write update result: %v", err) + } + }() + + if dryRun { + log.Infof("dry-run mode enabled, skipping actual installation") + resultErr = fmt.Errorf("dry-run mode enabled") + return + } + + installerType, err := typeByFileExtension(installerFile) + if err != nil { + log.Debugf("%v", err) + resultErr = err + return + } + + var cmd *exec.Cmd + switch installerType { + case TypeExe: + log.Infof("run exe installer: %s", installerFile) + cmd = exec.CommandContext(ctx, installerFile, "/S") + default: + installerDir := filepath.Dir(installerFile) + logPath := filepath.Join(installerDir, msiLogFile) + log.Infof("run msi installer: %s", installerFile) + cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath) + } + + cmd.Dir = filepath.Dir(installerFile) + + if resultErr = cmd.Start(); resultErr != nil { + log.Errorf("error starting installer: %v", resultErr) + return + } + + log.Infof("installer started with PID %d", cmd.Process.Pid) + if resultErr = cmd.Wait(); resultErr != nil { + log.Errorf("installer process finished with error: %v", resultErr) + return + } + + return nil +} + +func (u *Installer) startDaemon(daemonFolder string) error { + log.Infof("starting netbird service") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start") + if output, err := cmd.CombinedOutput(); err != nil { + log.Debugf("failed to start netbird service: %v, output: %s", err, string(output)) + return err + } + log.Infof("netbird service started successfully") + return nil +} + +func (u *Installer) startUIAsUser(daemonFolder string) error { + uiPath := filepath.Join(daemonFolder, uiName) + log.Infof("starting netbird-ui: %s", uiPath) + + // Get the active console session ID + sessionID := windows.WTSGetActiveConsoleSessionId() + if sessionID == 0xFFFFFFFF { + return fmt.Errorf("no active user session found") + } + + // Get the user token for that session + var userToken windows.Token + err := windows.WTSQueryUserToken(sessionID, &userToken) + if err != nil { + return fmt.Errorf("failed to query user token: %w", err) + } + defer func() { + if err := userToken.Close(); err != nil { + log.Warnf("failed to close user token: %v", err) + } + }() + + // Duplicate the token to a primary token + var primaryToken windows.Token + err = windows.DuplicateTokenEx( + userToken, + windows.MAXIMUM_ALLOWED, + nil, + windows.SecurityImpersonation, + windows.TokenPrimary, + &primaryToken, + ) + if err != nil { + return fmt.Errorf("failed to duplicate token: %w", err) + } + defer func() { + if err := primaryToken.Close(); err != nil { + log.Warnf("failed to close token: %v", err) + } + }() + + // Prepare startup info + var si windows.StartupInfo + si.Cb = uint32(unsafe.Sizeof(si)) + si.Desktop = windows.StringToUTF16Ptr("winsta0\\default") + + var pi windows.ProcessInformation + + cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath)) + if err != nil { + return fmt.Errorf("failed to convert path to UTF16: %w", err) + } + + creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT + + err = windows.CreateProcessAsUser( + primaryToken, + nil, + cmdLine, + nil, + nil, + false, + creationFlags, + nil, + nil, + &si, + &pi, + ) + if err != nil { + return fmt.Errorf("CreateProcessAsUser failed: %w", err) + } + + // Close handles + if err := windows.CloseHandle(pi.Process); err != nil { + log.Warnf("failed to close process handle: %v", err) + } + if err := windows.CloseHandle(pi.Thread); err != nil { + log.Warnf("failed to close thread handle: %v", err) + } + + log.Infof("netbird-ui started successfully in session %d", sessionID) + return nil +} + +func urlWithVersionArch(it Type, version string) string { + var url string + if it == TypeExe { + url = exeDownloadURL + } else { + url = msiDownloadURL + } + url = strings.ReplaceAll(url, "%version", version) + return strings.ReplaceAll(url, "%arch", runtime.GOARCH) +} diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updatemanager/installer/log.go new file mode 100644 index 000000000..8b60dba28 --- /dev/null +++ b/client/internal/updatemanager/installer/log.go @@ -0,0 +1,5 @@ +package installer + +const ( + LogFile = "installer.log" +) diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updatemanager/installer/procattr_darwin.go new file mode 100644 index 000000000..56f2018bb --- /dev/null +++ b/client/internal/updatemanager/installer/procattr_darwin.go @@ -0,0 +1,15 @@ +package installer + +import ( + "os/exec" + "syscall" +) + +// setUpdaterProcAttr configures the updater process to run in a new session, +// making it independent of the parent daemon process. This ensures the updater +// survives when the daemon is stopped during the pkg installation. +func setUpdaterProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + } +} diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updatemanager/installer/procattr_windows.go new file mode 100644 index 000000000..29a8a2de0 --- /dev/null +++ b/client/internal/updatemanager/installer/procattr_windows.go @@ -0,0 +1,14 @@ +package installer + +import ( + "os/exec" + "syscall" +) + +// setUpdaterProcAttr configures the updater process to run detached from the parent, +// making it independent of the parent daemon process. +func setUpdaterProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS + } +} diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updatemanager/installer/repourl_dev.go new file mode 100644 index 000000000..088821ad3 --- /dev/null +++ b/client/internal/updatemanager/installer/repourl_dev.go @@ -0,0 +1,7 @@ +//go:build devartifactsign + +package installer + +const ( + DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo" +) diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updatemanager/installer/repourl_prod.go new file mode 100644 index 000000000..abddc62c1 --- /dev/null +++ b/client/internal/updatemanager/installer/repourl_prod.go @@ -0,0 +1,7 @@ +//go:build !devartifactsign + +package installer + +const ( + DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures" +) diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updatemanager/installer/result.go new file mode 100644 index 000000000..03d08d527 --- /dev/null +++ b/client/internal/updatemanager/installer/result.go @@ -0,0 +1,230 @@ +package installer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" + log "github.com/sirupsen/logrus" +) + +const ( + resultFile = "result.json" +) + +type Result struct { + Success bool + Error string + ExecutedAt time.Time +} + +// ResultHandler handles reading and writing update results +type ResultHandler struct { + resultFile string +} + +// NewResultHandler creates a new communicator with the given directory path +// The result file will be created as "result.json" in the specified directory +func NewResultHandler(installerDir string) *ResultHandler { + // Create it if it doesn't exist + // do not care if already exists + _ = os.MkdirAll(installerDir, 0o700) + + rh := &ResultHandler{ + resultFile: filepath.Join(installerDir, resultFile), + } + return rh +} + +func (rh *ResultHandler) GetErrorResultReason() string { + result, err := rh.tryReadResult() + if err == nil && !result.Success { + return result.Error + } + + if err := rh.cleanup(); err != nil { + log.Warnf("failed to cleanup result file: %v", err) + } + + return "" +} + +func (rh *ResultHandler) WriteSuccess() error { + result := Result{ + Success: true, + ExecutedAt: time.Now(), + } + return rh.write(result) +} + +func (rh *ResultHandler) WriteErr(errReason error) error { + result := Result{ + Success: false, + Error: errReason.Error(), + ExecutedAt: time.Now(), + } + return rh.write(result) +} + +func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) { + log.Infof("start watching result: %s", rh.resultFile) + + // Check if file already exists (updater finished before we started watching) + if result, err := rh.tryReadResult(); err == nil { + log.Infof("installer result: %v", result) + return result, nil + } + + dir := filepath.Dir(rh.resultFile) + + if err := rh.waitForDirectory(ctx, dir); err != nil { + return Result{}, err + } + + return rh.watchForResultFile(ctx, dir) +} + +func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error { + ticker := time.NewTicker(300 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if info, err := os.Stat(dir); err == nil && info.IsDir() { + return nil + } + } + } +} + +func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Error(err) + return Result{}, err + } + + defer func() { + if err := watcher.Close(); err != nil { + log.Warnf("failed to close watcher: %v", err) + } + }() + + if err := watcher.Add(dir); err != nil { + return Result{}, fmt.Errorf("failed to watch directory: %v", err) + } + + // Check again after setting up watcher to avoid race condition + // (file could have been created between initial check and watcher setup) + if result, err := rh.tryReadResult(); err == nil { + log.Infof("installer result: %v", result) + return result, nil + } + + for { + select { + case <-ctx.Done(): + return Result{}, ctx.Err() + case event, ok := <-watcher.Events: + if !ok { + return Result{}, errors.New("watcher closed unexpectedly") + } + + if result, done := rh.handleWatchEvent(event); done { + return result, nil + } + case err, ok := <-watcher.Errors: + if !ok { + return Result{}, errors.New("watcher closed unexpectedly") + } + return Result{}, fmt.Errorf("watcher error: %w", err) + } + } +} + +func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) { + if event.Name != rh.resultFile { + return Result{}, false + } + + if event.Has(fsnotify.Create) { + result, err := rh.tryReadResult() + if err != nil { + log.Debugf("error while reading result: %v", err) + return result, true + } + log.Infof("installer result: %v", result) + return result, true + } + + return Result{}, false +} + +// Write writes the update result to a file for the UI to read +func (rh *ResultHandler) write(result Result) error { + log.Infof("write out installer result to: %s", rh.resultFile) + // Ensure directory exists + dir := filepath.Dir(rh.resultFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + log.Errorf("failed to create directory %s: %v", dir, err) + return err + } + + data, err := json.Marshal(result) + if err != nil { + return err + } + + // Write to a temporary file first, then rename for atomic operation + tmpPath := rh.resultFile + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o600); err != nil { + log.Errorf("failed to create temp file: %s", err) + return err + } + + // Atomic rename + if err := os.Rename(tmpPath, rh.resultFile); err != nil { + if cleanupErr := os.Remove(tmpPath); cleanupErr != nil { + log.Warnf("Failed to remove temp result file: %v", err) + } + return err + } + + return nil +} + +func (rh *ResultHandler) cleanup() error { + err := os.Remove(rh.resultFile) + if err != nil && !os.IsNotExist(err) { + return err + } + log.Debugf("delete installer result file: %s", rh.resultFile) + return nil +} + +// tryReadResult attempts to read and validate the result file +func (rh *ResultHandler) tryReadResult() (Result, error) { + data, err := os.ReadFile(rh.resultFile) + if err != nil { + return Result{}, err + } + + var result Result + if err := json.Unmarshal(data, &result); err != nil { + return Result{}, fmt.Errorf("invalid result format: %w", err) + } + + if err := rh.cleanup(); err != nil { + log.Warnf("failed to cleanup result file: %v", err) + } + + return result, nil +} diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updatemanager/installer/types.go new file mode 100644 index 000000000..656d84f88 --- /dev/null +++ b/client/internal/updatemanager/installer/types.go @@ -0,0 +1,14 @@ +package installer + +type Type struct { + name string + downloadable bool +} + +func (t Type) String() string { + return t.name +} + +func (t Type) Downloadable() bool { + return t.downloadable +} diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updatemanager/installer/types_darwin.go new file mode 100644 index 000000000..95a0cb737 --- /dev/null +++ b/client/internal/updatemanager/installer/types_darwin.go @@ -0,0 +1,22 @@ +package installer + +import ( + "context" + "os/exec" +) + +var ( + TypeHomebrew = Type{name: "Homebrew", downloadable: false} + TypePKG = Type{name: "pkg", downloadable: true} +) + +func TypeOfInstaller(ctx context.Context) Type { + cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client") + _, err := cmd.Output() + if err != nil && cmd.ProcessState.ExitCode() == 1 { + // Not installed using pkg file, thus installed using Homebrew + + return TypeHomebrew + } + return TypePKG +} diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updatemanager/installer/types_windows.go new file mode 100644 index 000000000..d4e5d83bd --- /dev/null +++ b/client/internal/updatemanager/installer/types_windows.go @@ -0,0 +1,51 @@ +package installer + +import ( + "context" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows/registry" +) + +const ( + uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird` + uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird` +) + +var ( + TypeExe = Type{name: "EXE", downloadable: true} + TypeMSI = Type{name: "MSI", downloadable: true} +) + +func TypeOfInstaller(_ context.Context) Type { + paths := []string{uninstallKeyPath64, uninstallKeyPath32} + + for _, path := range paths { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + continue + } + + if err := k.Close(); err != nil { + log.Warnf("Error closing registry key: %v", err) + } + return TypeExe + + } + + log.Debug("No registry entry found for Netbird, assuming MSI installation") + return TypeMSI +} + +func typeByFileExtension(filePath string) (Type, error) { + switch { + case strings.HasSuffix(strings.ToLower(filePath), ".exe"): + return TypeExe, nil + case strings.HasSuffix(strings.ToLower(filePath), ".msi"): + return TypeMSI, nil + default: + return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath) + } +} diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go new file mode 100644 index 000000000..eae11de56 --- /dev/null +++ b/client/internal/updatemanager/manager.go @@ -0,0 +1,374 @@ +//go:build windows || darwin + +package updatemanager + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "time" + + v "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + cProto "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/version" +) + +const ( + latestVersion = "latest" + // this version will be ignored + developmentVersion = "development" +) + +var errNoUpdateState = errors.New("no update state found") + +type UpdateState struct { + PreUpdateVersion string + TargetVersion string +} + +func (u UpdateState) Name() string { + return "autoUpdate" +} + +type Manager struct { + statusRecorder *peer.Status + stateManager *statemanager.Manager + + lastTrigger time.Time + mgmUpdateChan chan struct{} + updateChannel chan struct{} + currentVersion string + update UpdateInterface + wg sync.WaitGroup + + cancel context.CancelFunc + + expectedVersion *v.Version + updateToLatestVersion bool + + // updateMutex protect update and expectedVersion fields + updateMutex sync.Mutex + + triggerUpdateFn func(context.Context, string) error +} + +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + if runtime.GOOS == "darwin" { + isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable() + if isBrew { + log.Warnf("auto-update disabled on Home Brew installation") + return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet") + } + } + return newManager(statusRecorder, stateManager) +} + +func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + manager := &Manager{ + statusRecorder: statusRecorder, + stateManager: stateManager, + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + update: version.NewUpdate("nb/client"), + } + manager.triggerUpdateFn = manager.triggerUpdate + + stateManager.RegisterState(&UpdateState{}) + + return manager, nil +} + +// CheckUpdateSuccess checks if the update was successful and send a notification. +// It works without to start the update manager. +func (m *Manager) CheckUpdateSuccess(ctx context.Context) { + reason := m.lastResultErrReason() + if reason != "" { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_ERROR, + cProto.SystemEvent_SYSTEM, + "Auto-update failed", + fmt.Sprintf("Auto-update failed: %s", reason), + nil, + ) + } + + updateState, err := m.loadAndDeleteUpdateState(ctx) + if err != nil { + if errors.Is(err, errNoUpdateState) { + return + } + log.Errorf("failed to load update state: %v", err) + return + } + + log.Debugf("auto-update state loaded, %v", *updateState) + + if updateState.TargetVersion == m.currentVersion { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "Auto-update completed", + fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion), + nil, + ) + return + } +} + +func (m *Manager) Start(ctx context.Context) { + if m.cancel != nil { + log.Errorf("Manager already started") + return + } + + m.update.SetDaemonVersion(m.currentVersion) + m.update.SetOnUpdateListener(func() { + select { + case m.updateChannel <- struct{}{}: + default: + } + }) + go m.update.StartFetcher() + + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + + m.wg.Add(1) + go m.updateLoop(ctx) +} + +func (m *Manager) SetVersion(expectedVersion string) { + log.Infof("set expected agent version for upgrade: %s", expectedVersion) + if m.cancel == nil { + log.Errorf("manager not started") + return + } + + m.updateMutex.Lock() + defer m.updateMutex.Unlock() + + if expectedVersion == "" { + log.Errorf("empty expected version provided") + m.expectedVersion = nil + m.updateToLatestVersion = false + return + } + + if expectedVersion == latestVersion { + m.updateToLatestVersion = true + m.expectedVersion = nil + } else { + expectedSemVer, err := v.NewVersion(expectedVersion) + if err != nil { + log.Errorf("error parsing version: %v", err) + return + } + if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) { + return + } + m.expectedVersion = expectedSemVer + m.updateToLatestVersion = false + } + + select { + case m.mgmUpdateChan <- struct{}{}: + default: + } +} + +func (m *Manager) Stop() { + if m.cancel == nil { + return + } + + m.cancel() + m.updateMutex.Lock() + if m.update != nil { + m.update.StopWatch() + m.update = nil + } + m.updateMutex.Unlock() + + m.wg.Wait() +} + +func (m *Manager) onContextCancel() { + if m.cancel == nil { + return + } + + m.updateMutex.Lock() + defer m.updateMutex.Unlock() + if m.update != nil { + m.update.StopWatch() + m.update = nil + } +} + +func (m *Manager) updateLoop(ctx context.Context) { + defer m.wg.Done() + + for { + select { + case <-ctx.Done(): + m.onContextCancel() + return + case <-m.mgmUpdateChan: + case <-m.updateChannel: + log.Infof("fetched new version info") + } + + m.handleUpdate(ctx) + } +} + +func (m *Manager) handleUpdate(ctx context.Context) { + var updateVersion *v.Version + + m.updateMutex.Lock() + if m.update == nil { + m.updateMutex.Unlock() + return + } + + expectedVersion := m.expectedVersion + useLatest := m.updateToLatestVersion + curLatestVersion := m.update.LatestVersion() + m.updateMutex.Unlock() + + switch { + // Resolve "latest" to actual version + case useLatest: + if curLatestVersion == nil { + log.Tracef("latest version not fetched yet") + return + } + updateVersion = curLatestVersion + // Update to specific version + case expectedVersion != nil: + updateVersion = expectedVersion + default: + log.Debugf("no expected version information set") + return + } + + log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion) + if !m.shouldUpdate(updateVersion) { + return + } + + m.lastTrigger = time.Now() + log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion) + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "Automatically updating client", + "Your client version is older than auto-update version set in Management, updating client now.", + nil, + ) + + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "", + "", + map[string]string{"progress_window": "show", "version": updateVersion.String()}, + ) + + updateState := UpdateState{ + PreUpdateVersion: m.currentVersion, + TargetVersion: updateVersion.String(), + } + + if err := m.stateManager.UpdateState(updateState); err != nil { + log.Warnf("failed to update state: %v", err) + } else { + if err = m.stateManager.PersistState(ctx); err != nil { + log.Warnf("failed to persist state: %v", err) + } + } + + if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil { + log.Errorf("Error triggering auto-update: %v", err) + m.statusRecorder.PublishEvent( + cProto.SystemEvent_ERROR, + cProto.SystemEvent_SYSTEM, + "Auto-update failed", + fmt.Sprintf("Auto-update failed: %v", err), + nil, + ) + } +} + +// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it. +// Returns nil if no state exists. +func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) { + stateType := &UpdateState{} + + m.stateManager.RegisterState(stateType) + if err := m.stateManager.LoadState(stateType); err != nil { + return nil, fmt.Errorf("load state: %w", err) + } + + state := m.stateManager.GetState(stateType) + if state == nil { + return nil, errNoUpdateState + } + + updateState, ok := state.(*UpdateState) + if !ok { + return nil, fmt.Errorf("failed to cast state to UpdateState") + } + + if err := m.stateManager.DeleteState(updateState); err != nil { + return nil, fmt.Errorf("delete state: %w", err) + } + + if err := m.stateManager.PersistState(ctx); err != nil { + return nil, fmt.Errorf("persist state: %w", err) + } + + return updateState, nil +} + +func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { + if m.currentVersion == developmentVersion { + log.Debugf("skipping auto-update, running development version") + return false + } + currentVersion, err := v.NewVersion(m.currentVersion) + if err != nil { + log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err) + return false + } + if currentVersion.GreaterThanOrEqual(updateVersion) { + log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion) + return false + } + + if time.Since(m.lastTrigger) < 5*time.Minute { + log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) + return false + } + + return true +} + +func (m *Manager) lastResultErrReason() string { + inst := installer.New() + result := installer.NewResultHandler(inst.TempDir()) + return result.GetErrorResultReason() +} + +func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error { + inst := installer.New() + return inst.RunInstallation(ctx, targetVersion) +} diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go new file mode 100644 index 000000000..20ddec10d --- /dev/null +++ b/client/internal/updatemanager/manager_test.go @@ -0,0 +1,214 @@ +//go:build windows || darwin + +package updatemanager + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (v versionUpdateMock) StopWatch() {} + +func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + v.onUpdate = updateFn +} + +func (v versionUpdateMock) LatestVersion() *v.Version { + return v.latestVersion +} + +func (v versionUpdateMock) StartFetcher() {} + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should only trigger update once due to time between triggers being < 5 Minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: false, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) + m.update = mockUpdate + + targetVersionChan := make(chan string, 1) + + m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetVersion("latest") + var triggeredInit bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) + } + triggeredInit = true + case <-time.After(10 * time.Millisecond): + triggeredInit = false + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + var triggeredLater bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } + triggeredLater = true + case <-time.After(10 * time.Millisecond): + triggeredLater = false + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Update to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Update to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Update to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Update to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Update to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for idx, c := range testMatrix { + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: c.latestVersion} + targetVersionChan := make(chan string, 1) + + m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetVersion(c.expectedVersion) + + var updateTriggered bool + select { + case targetVersion := <-targetVersionChan: + if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion) + } + updateTriggered = true + case <-time.After(10 * time.Millisecond): + updateTriggered = false + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go new file mode 100644 index 000000000..4e87c2d77 --- /dev/null +++ b/client/internal/updatemanager/manager_unsupported.go @@ -0,0 +1,39 @@ +//go:build !windows && !darwin + +package updatemanager + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// Manager is a no-op stub for unsupported platforms +type Manager struct{} + +// NewManager returns a no-op manager for unsupported platforms +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + return nil, fmt.Errorf("update manager is not supported on this platform") +} + +// CheckUpdateSuccess is a no-op on unsupported platforms +func (m *Manager) CheckUpdateSuccess(ctx context.Context) { + // no-op +} + +// Start is a no-op on unsupported platforms +func (m *Manager) Start(ctx context.Context) { + // no-op +} + +// SetVersion is a no-op on unsupported platforms +func (m *Manager) SetVersion(expectedVersion string) { + // no-op +} + +// Stop is a no-op on unsupported platforms +func (m *Manager) Stop() { + // no-op +} diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updatemanager/reposign/artifact.go new file mode 100644 index 000000000..3d4fe9c74 --- /dev/null +++ b/client/internal/updatemanager/reposign/artifact.go @@ -0,0 +1,302 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "hash" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/blake2s" +) + +const ( + tagArtifactPrivate = "ARTIFACT PRIVATE KEY" + tagArtifactPublic = "ARTIFACT PUBLIC KEY" + + maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour + maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour +) + +// ArtifactHash wraps a hash.Hash and counts bytes written +type ArtifactHash struct { + hash.Hash +} + +// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s +func NewArtifactHash() *ArtifactHash { + h, err := blake2s.New256(nil) + if err != nil { + panic(err) // Should never happen with nil Key + } + return &ArtifactHash{Hash: h} +} + +func (ah *ArtifactHash) Write(b []byte) (int, error) { + return ah.Hash.Write(b) +} + +// ArtifactKey is a signing Key used to sign artifacts +type ArtifactKey struct { + PrivateKey +} + +func (k ArtifactKey) String() string { + return fmt.Sprintf( + "ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", + k.Metadata.ID, + k.Metadata.CreatedAt.Format(time.RFC3339), + k.Metadata.ExpiresAt.Format(time.RFC3339), + ) +} + +func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) { + // Verify root key is still valid + if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) { + return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339)) + } + + now := time.Now() + expirationTime := now.Add(expiration) + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err) + } + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: now.UTC(), + ExpiresAt: expirationTime.UTC(), + } + + ak := &ArtifactKey{ + PrivateKey{ + Key: priv, + Metadata: metadata, + }, + } + + // Marshal PrivateKey struct to JSON + privJSON, err := json.Marshal(ak.PrivateKey) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Marshal PublicKey struct to JSON + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + pubJSON, err := json.Marshal(pubKey) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM with metadata embedded in bytes + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPrivate, + Bytes: privJSON, + }) + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPublic, + Bytes: pubJSON, + }) + + // Sign the public key with the root key + signature, err := SignArtifactKey(*rootKey, pubPEM) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err) + } + + return ak, privPEM, pubPEM, signature, nil +} + +func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) { + pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate) + if err != nil { + return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err) + } + return ArtifactKey{pk}, nil +} + +func ParseArtifactPubKey(data []byte) (PublicKey, error) { + pk, _, err := parsePublicKey(data, tagArtifactPublic) + return pk, err +} + +func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) { + if len(keys) == 0 { + return nil, nil, errors.New("no keys to bundle") + } + + // Create bundle by concatenating PEM-encoded keys + var pubBundle []byte + + for _, pk := range keys { + // Marshal PublicKey struct to JSON + pubJSON, err := json.Marshal(pk) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPublic, + Bytes: pubJSON, + }) + + pubBundle = append(pubBundle, pubPEM...) + } + + // Sign the entire bundle with the root key + signature, err := SignArtifactKey(*rootKey, pubBundle) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err) + } + + return pubBundle, signature, nil +} + +func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) { + now := time.Now().UTC() + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("artifact signature error: %v", err) + return nil, err + } + if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge { + err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp) + log.Debugf("artifact signature error: %v", err) + return nil, err + } + + // Reconstruct the signed message: artifact_key_data || timestamp + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + if !verifyAny(publicRootKeys, msg, signature.Signature) { + return nil, errors.New("failed to verify signature of artifact keys") + } + + pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic) + if err != nil { + log.Debugf("failed to parse public keys: %s", err) + return nil, err + } + + validKeys := make([]PublicKey, 0, len(pubKeys)) + for _, pubKey := range pubKeys { + // Filter out expired keys + if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) { + log.Debugf("Key %s is expired at %v (current time %v)", + pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now) + continue + } + + if revocationList != nil { + if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked { + log.Debugf("Key %s is revoked as of %v (created %v)", + pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt) + continue + } + } + validKeys = append(validKeys, pubKey) + } + + if len(validKeys) == 0 { + log.Debugf("no valid public keys found for artifact keys") + return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys)) + } + + return validKeys, nil +} + +func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error { + // Validate signature timestamp + now := time.Now().UTC() + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("failed to verify signature of artifact: %s", err) + return err + } + if now.Sub(signature.Timestamp) > maxArtifactSignatureAge { + return fmt.Errorf("artifact signature is too old: %v (created %v)", + now.Sub(signature.Timestamp), signature.Timestamp) + } + + h := NewArtifactHash() + if _, err := h.Write(data); err != nil { + return fmt.Errorf("failed to hash artifact: %w", err) + } + hash := h.Sum(nil) + + // Reconstruct the signed message: hash || length || timestamp + msg := make([]byte, 0, len(hash)+8+8) + msg = append(msg, hash...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data))) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + // Find matching Key and verify + for _, keyInfo := range artifactPubKeys { + if keyInfo.Metadata.ID == signature.KeyID { + // Check Key expiration + if !keyInfo.Metadata.ExpiresAt.IsZero() && + signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) { + return fmt.Errorf("signing Key %s expired at %v, signature from %v", + signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp) + } + + if ed25519.Verify(keyInfo.Key, msg, signature.Signature) { + log.Debugf("artifact verified successfully with Key: %s", signature.KeyID) + return nil + } + return fmt.Errorf("signature verification failed for Key %s", signature.KeyID) + } + } + + return fmt.Errorf("no signing Key found with ID %s", signature.KeyID) +} + +func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) { + if len(data) == 0 { // Check happens too late + return nil, fmt.Errorf("artifact length must be positive, got %d", len(data)) + } + + h := NewArtifactHash() + if _, err := h.Write(data); err != nil { + return nil, fmt.Errorf("failed to write artifact hash: %w", err) + } + + timestamp := time.Now().UTC() + + if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) { + return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt) + } + + hash := h.Sum(nil) + + // Create message: hash || length || timestamp + msg := make([]byte, 0, len(hash)+8+8) + msg = append(msg, hash...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data))) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(artifactKey.Key, msg) + + bundle := Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: artifactKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + return json.Marshal(bundle) +} diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updatemanager/reposign/artifact_test.go new file mode 100644 index 000000000..8865e2d0a --- /dev/null +++ b/client/internal/updatemanager/reposign/artifact_test.go @@ -0,0 +1,1080 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test ArtifactHash + +func TestNewArtifactHash(t *testing.T) { + h := NewArtifactHash() + assert.NotNil(t, h) + assert.NotNil(t, h.Hash) +} + +func TestArtifactHash_Write(t *testing.T) { + h := NewArtifactHash() + + data := []byte("test data") + n, err := h.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + + hash := h.Sum(nil) + assert.NotEmpty(t, hash) + assert.Equal(t, 32, len(hash)) // BLAKE2s-256 +} + +func TestArtifactHash_Deterministic(t *testing.T) { + data := []byte("test data") + + h1 := NewArtifactHash() + if _, err := h1.Write(data); err != nil { + t.Fatal(err) + } + hash1 := h1.Sum(nil) + + h2 := NewArtifactHash() + if _, err := h2.Write(data); err != nil { + t.Fatal(err) + } + hash2 := h2.Sum(nil) + + assert.Equal(t, hash1, hash2) +} + +func TestArtifactHash_DifferentData(t *testing.T) { + h1 := NewArtifactHash() + if _, err := h1.Write([]byte("data1")); err != nil { + t.Fatal(err) + } + hash1 := h1.Sum(nil) + + h2 := NewArtifactHash() + if _, err := h2.Write([]byte("data2")); err != nil { + t.Fatal(err) + } + hash2 := h2.Sum(nil) + + assert.NotEqual(t, hash1, hash2) +} + +// Test ArtifactKey.String() + +func TestArtifactKey_String(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + expiresAt := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + + ak := ArtifactKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: expiresAt, + }, + }, + } + + str := ak.String() + assert.Contains(t, str, "ArtifactKey") + assert.Contains(t, str, computeKeyID(pub).String()) + assert.Contains(t, str, "2024-01-15") + assert.Contains(t, str, "2025-01-15") +} + +// Test GenerateArtifactKey + +func TestGenerateArtifactKey_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + }, + }, + } + + // Generate artifact key + ak, privPEM, pubPEM, signature, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, ak) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + assert.NotEmpty(t, signature) + + // Verify expiration + assert.True(t, ak.Metadata.ExpiresAt.After(time.Now())) + assert.True(t, ak.Metadata.ExpiresAt.Before(time.Now().Add(31*24*time.Hour))) +} + +func TestGenerateArtifactKey_ExpiredRoot(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create expired root key + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().Add(-2 * 365 * 24 * time.Hour).UTC(), + ExpiresAt: time.Now().Add(-1 * time.Hour).UTC(), // Expired + }, + }, + } + + _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestGenerateArtifactKey_NoExpiration(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Root key with no expiration + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Time{}, // No expiration + }, + }, + } + + ak, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, ak) +} + +// Test ParseArtifactKey + +func TestParseArtifactKey_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + original, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Parse it back + parsed, err := ParseArtifactKey(privPEM) + require.NoError(t, err) + + assert.Equal(t, original.Key, parsed.Key) + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) +} + +func TestParseArtifactKey_InvalidPEM(t *testing.T) { + _, err := ParseArtifactKey([]byte("invalid pem")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestParseArtifactKey_WrongType(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create a root key (wrong type) + rootKey := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + privJSON, err := json.Marshal(rootKey.PrivateKey) + require.NoError(t, err) + + privPEM := encodePrivateKey(privJSON, tagRootPrivate) + + _, err = ParseArtifactKey(privPEM) + assert.Error(t, err) +} + +// Test ParseArtifactPubKey + +func TestParseArtifactPubKey_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + original, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + parsed, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) +} + +func TestParseArtifactPubKey_Invalid(t *testing.T) { + _, err := ParseArtifactPubKey([]byte("invalid")) + assert.Error(t, err) +} + +// Test BundleArtifactKeys + +func TestBundleArtifactKeys_Single(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, signature, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + assert.NotEmpty(t, bundle) + assert.NotEmpty(t, signature) +} + +func TestBundleArtifactKeys_Multiple(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate 3 artifact keys + var pubKeys []PublicKey + for i := 0; i < 3; i++ { + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + pubKeys = append(pubKeys, pubKey) + } + + bundle, signature, err := BundleArtifactKeys(rootKey, pubKeys) + require.NoError(t, err) + assert.NotEmpty(t, bundle) + assert.NotEmpty(t, signature) + + // Verify we can parse the bundle + parsed, err := parsePublicKeyBundle(bundle, tagArtifactPublic) + require.NoError(t, err) + assert.Len(t, parsed, 3) +} + +func TestBundleArtifactKeys_Empty(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, err = BundleArtifactKeys(rootKey, []PublicKey{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no keys") +} + +// Test ValidateArtifactKeys + +func TestSingleValidateArtifactKey_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + _, _, pubPEM, sigData, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + sig, _ := ParseSignature(sigData) + + // Validate + validKeys, err := ValidateArtifactKeys(rootKeys, pubPEM, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) +} + +func TestValidateArtifactKeys_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + // Bundle and sign + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Validate + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) +} + +func TestValidateArtifactKeys_FutureTimestamp(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(10 * time.Minute), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateArtifactKeys_TooOld(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateArtifactKeys_InvalidSignature(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, _, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + // Create invalid signature + invalidSig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, bundle, invalidSig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to verify") +} + +func TestValidateArtifactKeys_WithRevocation(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate two artifact keys + _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey1, err := ParseArtifactPubKey(pubPEM1) + require.NoError(t, err) + + _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey2, err := ParseArtifactPubKey(pubPEM2) + require.NoError(t, err) + + // Bundle both keys + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Create revocation list with first key revoked + revocationList := &RevocationList{ + Revoked: map[KeyID]time.Time{ + pubKey1.Metadata.ID: time.Now().UTC(), + }, + LastUpdated: time.Now().UTC(), + } + + // Validate - should only return second key + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + assert.Equal(t, pubKey2.Metadata.ID, validKeys[0].Metadata.ID) +} + +func TestValidateArtifactKeys_AllRevoked(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Revoke the key + revocationList := &RevocationList{ + Revoked: map[KeyID]time.Time{ + pubKey.Metadata.ID: time.Now().UTC(), + }, + LastUpdated: time.Now().UTC(), + } + + _, err = ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList) + assert.Error(t, err) + assert.Contains(t, err.Error(), "revoked") +} + +// Test ValidateArtifact + +func TestValidateArtifact_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign some data + data := []byte("test artifact data") + sigData, err := SignData(*artifactKey, data) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Get public key for validation + artifactPubKey := PublicKey{ + Key: artifactKey.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey.Metadata, + } + + // Validate + err = ValidateArtifact([]PublicKey{artifactPubKey}, data, *sig) + assert.NoError(t, err) +} + +func TestValidateArtifact_FutureTimestamp(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(10 * time.Minute), + KeyID: computeKeyID(pub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateArtifact_TooOld(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour), + KeyID: computeKeyID(pub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateArtifact_ExpiredKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key with very short expiration + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + + // Wait for key to expire + time.Sleep(10 * time.Millisecond) + + // Try to sign - should succeed but with old timestamp + data := []byte("test data") + sigData, err := SignData(*artifactKey, data) + require.Error(t, err) // Key is expired, so signing should fail + assert.Contains(t, err.Error(), "expired") + assert.Nil(t, sigData) +} + +func TestValidateArtifact_WrongKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate two artifact keys + artifactKey1, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactKey2, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign with key1 + data := []byte("test data") + sigData, err := SignData(*artifactKey1, data) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Try to validate with key2 only + artifactPubKey2 := PublicKey{ + Key: artifactKey2.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey2.Metadata, + } + + err = ValidateArtifact([]PublicKey{artifactPubKey2}, data, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no signing Key found") +} + +func TestValidateArtifact_TamperedData(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign original data + originalData := []byte("original data") + sigData, err := SignData(*artifactKey, originalData) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: artifactKey.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey.Metadata, + } + + // Try to validate with tampered data + tamperedData := []byte("tampered data") + err = ValidateArtifact([]PublicKey{artifactPubKey}, tamperedData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestValidateArtifactKeys_TwoKeysOneExpired(t *testing.T) { + // Test ValidateArtifactKeys with a bundle containing two keys where one is expired + // Should return only the valid key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate first key with very short expiration + _, _, expiredPubPEM, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + expiredPubKey, err := ParseArtifactPubKey(expiredPubPEM) + require.NoError(t, err) + + // Wait for first key to expire + time.Sleep(10 * time.Millisecond) + + // Generate second key with normal expiration + _, _, validPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + validPubKey, err := ParseArtifactPubKey(validPubPEM) + require.NoError(t, err) + + // Bundle both keys together + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{expiredPubKey, validPubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // ValidateArtifactKeys should return only the valid key + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + assert.Equal(t, validPubKey.Metadata.ID, validKeys[0].Metadata.ID) +} + +func TestValidateArtifactKeys_TwoKeysBothExpired(t *testing.T) { + // Test ValidateArtifactKeys with a bundle containing two expired keys + // Should fail because no valid keys remain + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate first key with + _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 24*time.Hour) + require.NoError(t, err) + pubKey1, err := ParseArtifactPubKey(pubPEM1) + require.NoError(t, err) + + // Generate second key with very short expiration + _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + pubKey2, err := ParseArtifactPubKey(pubPEM2) + require.NoError(t, err) + + // Wait for expire + time.Sleep(10 * time.Millisecond) + + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // ValidateArtifactKeys should fail because all keys are expired + keys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + assert.NoError(t, err) + assert.Len(t, keys, 1) +} + +// Test SignData + +func TestSignData_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + data := []byte("test data to sign") + sigData, err := SignData(*artifactKey, data) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} + +func TestSignData_EmptyData(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + _, err = SignData(*artifactKey, []byte{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") +} + +func TestSignData_ExpiredKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate key with very short expiration + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to sign with expired key + _, err = SignData(*artifactKey, []byte("data")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +// Integration test + +func TestArtifact_FullWorkflow(t *testing.T) { + // Step 1: Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Step 2: Generate artifact key + artifactKey, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Step 3: Create and validate key bundle + artifactPubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, bundleSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(bundleSig) + require.NoError(t, err) + + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + + // Step 4: Sign artifact data + artifactData := []byte("This is my artifact data that needs to be signed") + artifactSig, err := SignData(*artifactKey, artifactData) + require.NoError(t, err) + + // Step 5: Validate artifact + parsedSig, err := ParseSignature(artifactSig) + require.NoError(t, err) + + err = ValidateArtifact(validKeys, artifactData, *parsedSig) + assert.NoError(t, err) +} + +// Helper function for tests +func encodePrivateKey(jsonData []byte, typeTag string) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: typeTag, + Bytes: jsonData, + }) +} diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updatemanager/reposign/certs/root-pub.pem new file mode 100644 index 000000000..e7c2fd2c0 --- /dev/null +++ b/client/internal/updatemanager/reposign/certs/root-pub.pem @@ -0,0 +1,6 @@ +-----BEGIN ROOT PUBLIC KEY----- +eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn +V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0 +ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz +X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19 +-----END ROOT PUBLIC KEY----- diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updatemanager/reposign/certsdev/root-pub.pem new file mode 100644 index 000000000..f7145477b --- /dev/null +++ b/client/internal/updatemanager/reposign/certsdev/root-pub.pem @@ -0,0 +1,6 @@ +-----BEGIN ROOT PUBLIC KEY----- +eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr +ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0 +ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz +X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19 +-----END ROOT PUBLIC KEY----- diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updatemanager/reposign/doc.go new file mode 100644 index 000000000..660b9d11d --- /dev/null +++ b/client/internal/updatemanager/reposign/doc.go @@ -0,0 +1,174 @@ +// Package reposign implements a cryptographic signing and verification system +// for NetBird software update artifacts. It provides a hierarchical key +// management system with support for key rotation, revocation, and secure +// artifact distribution. +// +// # Architecture +// +// The package uses a two-tier key hierarchy: +// +// - Root Keys: Long-lived keys that sign artifact keys. These are embedded +// in the client binary and establish the root of trust. Root keys should +// be kept offline and highly secured. +// +// - Artifact Keys: Short-lived keys that sign release artifacts (binaries, +// packages, etc.). These are rotated regularly and can be revoked if +// compromised. Artifact keys are signed by root keys and distributed via +// a public repository. +// +// This separation allows for operational flexibility: artifact keys can be +// rotated frequently without requiring client updates, while root keys remain +// stable and embedded in the software. +// +// # Cryptographic Primitives +// +// The package uses strong, modern cryptographic algorithms: +// - Ed25519: Fast, secure digital signatures (no timing attacks) +// - BLAKE2s-256: Fast cryptographic hash for artifacts +// - SHA-256: Key ID generation +// - JSON: Structured key and signature serialization +// - PEM: Standard key encoding format +// +// # Security Features +// +// Timestamp Binding: +// - All signatures include cryptographically-bound timestamps +// - Prevents replay attacks and enforces signature freshness +// - Clock skew tolerance: 5 minutes +// +// Key Expiration: +// - All keys have expiration times +// - Expired keys are automatically rejected +// - Signing with an expired key fails immediately +// +// Key Revocation: +// - Compromised keys can be revoked via a signed revocation list +// - Revocation list is checked during artifact validation +// - Revoked keys are filtered out before artifact verification +// +// # File Structure +// +// The package expects the following file layout in the key repository: +// +// signrepo/ +// artifact-key-pub.pem # Bundle of artifact public keys +// artifact-key-pub.pem.sig # Root signature of the bundle +// revocation-list.json # List of revoked key IDs +// revocation-list.json.sig # Root signature of revocation list +// +// And in the artifacts repository: +// +// releases/ +// v0.28.0/ +// netbird-linux-amd64 +// netbird-linux-amd64.sig # Artifact signature +// netbird-darwin-amd64 +// netbird-darwin-amd64.sig +// ... +// +// # Embedded Root Keys +// +// Root public keys are embedded in the client binary at compile time: +// - Production keys: certs/ directory +// - Development keys: certsdev/ directory +// +// The build tag determines which keys are embedded: +// - Production builds: //go:build !devartifactsign +// - Development builds: //go:build devartifactsign +// +// This ensures that development artifacts cannot be verified using production +// keys and vice versa. +// +// # Key Rotation Strategies +// +// Root Key Rotation: +// +// Root keys can be rotated without breaking existing clients by leveraging +// the multi-key verification system. The loadEmbeddedPublicKeys function +// reads ALL files from the certs/ directory and accepts signatures from ANY +// of the embedded root keys. +// +// To rotate root keys: +// +// 1. Generate a new root key pair: +// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour) +// +// 2. Add the new public key to the certs/ directory as a new file: +// certs/ +// root-pub-2024.pem # Old key (keep this!) +// root-pub-2025.pem # New key (add this) +// +// 3. Build new client versions with both keys embedded. The verification +// will accept signatures from either key. +// +// 4. Start signing new artifact keys with the new root key. Old clients +// with only the old root key will reject these, but new clients with +// both keys will accept them. +// +// Each file in certs/ can contain a single key or a bundle of keys (multiple +// PEM blocks). The system will parse all keys from all files and use them +// for verification. This provides maximum flexibility for key management. +// +// Important: Never remove all old root keys at once. Always maintain at least +// one overlapping key between releases to ensure smooth transitions. +// +// Artifact Key Rotation: +// +// Artifact keys should be rotated regularly (e.g., every 90 days) using the +// bundling mechanism. The BundleArtifactKeys function allows multiple artifact +// keys to be bundled together in a single signed package, and ValidateArtifact +// will accept signatures from ANY key in the bundle. +// +// To rotate artifact keys smoothly: +// +// 1. Generate a new artifact key while keeping the old one: +// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour) +// // Keep oldPubPEM and oldKey available +// +// 2. Create a bundle containing both old and new public keys +// +// 3. Upload the bundle and its signature to the key repository: +// signrepo/artifact-key-pub.pem # Contains both keys +// signrepo/artifact-key-pub.pem.sig # Root signature +// +// 4. Start signing new releases with the NEW key, but keep the bundle +// unchanged. Clients will download the bundle (containing both keys) +// and accept signatures from either key. +// +// Key bundle validation workflow: +// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig +// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key +// 3. ValidateArtifactKeys parses all public keys from the bundle +// 4. ValidateArtifactKeys filters out expired or revoked keys +// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds +// +// This multi-key acceptance model enables overlapping validity periods and +// smooth transitions without client update requirements. +// +// # Best Practices +// +// Root Key Management: +// - Generate root keys offline on an air-gapped machine +// - Store root private keys in hardware security modules (HSM) if possible +// - Use separate root keys for production and development +// - Rotate root keys infrequently (e.g., every 5-10 years) +// - Plan for root key rotation: embed multiple root public keys +// +// Artifact Key Management: +// - Rotate artifact keys regularly (e.g., every 90 days) +// - Use separate artifact keys for different release channels if needed +// - Revoke keys immediately upon suspected compromise +// - Bundle multiple artifact keys to enable smooth rotation +// +// Signing Process: +// - Sign artifacts in a secure CI/CD environment +// - Never commit private keys to version control +// - Use environment variables or secret management for keys +// - Verify signatures immediately after signing +// +// Distribution: +// - Serve keys and revocation lists from a reliable CDN +// - Use HTTPS for all key and artifact downloads +// - Monitor download failures and signature verification failures +// - Keep revocation list up to date +package reposign diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updatemanager/reposign/embed_dev.go new file mode 100644 index 000000000..ef8f77373 --- /dev/null +++ b/client/internal/updatemanager/reposign/embed_dev.go @@ -0,0 +1,10 @@ +//go:build devartifactsign + +package reposign + +import "embed" + +//go:embed certsdev +var embeddedCerts embed.FS + +const embeddedCertsDir = "certsdev" diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updatemanager/reposign/embed_prod.go new file mode 100644 index 000000000..91530e5f4 --- /dev/null +++ b/client/internal/updatemanager/reposign/embed_prod.go @@ -0,0 +1,10 @@ +//go:build !devartifactsign + +package reposign + +import "embed" + +//go:embed certs +var embeddedCerts embed.FS + +const embeddedCertsDir = "certs" diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updatemanager/reposign/key.go new file mode 100644 index 000000000..bedfef70d --- /dev/null +++ b/client/internal/updatemanager/reposign/key.go @@ -0,0 +1,171 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "time" +) + +const ( + maxClockSkew = 5 * time.Minute +) + +// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key) +type KeyID [8]byte + +// computeKeyID generates a unique ID from a public Key +func computeKeyID(pub ed25519.PublicKey) KeyID { + h := sha256.Sum256(pub) + var id KeyID + copy(id[:], h[:8]) + return id +} + +// MarshalJSON implements json.Marshaler for KeyID +func (k KeyID) MarshalJSON() ([]byte, error) { + return json.Marshal(k.String()) +} + +// UnmarshalJSON implements json.Unmarshaler for KeyID +func (k *KeyID) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + parsed, err := ParseKeyID(s) + if err != nil { + return err + } + + *k = parsed + return nil +} + +// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID. +func ParseKeyID(s string) (KeyID, error) { + var id KeyID + if len(s) != 16 { + return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s)) + } + + b, err := hex.DecodeString(s) + if err != nil { + return id, fmt.Errorf("failed to decode KeyID: %w", err) + } + + copy(id[:], b) + return id, nil +} + +func (k KeyID) String() string { + return fmt.Sprintf("%x", k[:]) +} + +// KeyMetadata contains versioning and lifecycle information for a Key +type KeyMetadata struct { + ID KeyID `json:"id"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration +} + +// PublicKey wraps a public Key with its Metadata +type PublicKey struct { + Key ed25519.PublicKey + Metadata KeyMetadata +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) { + var keys []PublicKey + for len(bundle) > 0 { + keyInfo, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, keyInfo) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no keys found in bundle") + } + return keys, nil +} + +func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) { + b, rest := pem.Decode(data) + if b == nil { + return PublicKey{}, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + + // Unmarshal JSON-embedded format + var pub PublicKey + if err := json.Unmarshal(b.Bytes, &pub); err != nil { + return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err) + } + + // Validate key length + if len(pub.Key) != ed25519.PublicKeySize { + return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d", + ed25519.PublicKeySize, len(pub.Key)) + } + + // Always recompute ID to ensure integrity + pub.Metadata.ID = computeKeyID(pub.Key) + + return pub, rest, nil +} + +type PrivateKey struct { + Key ed25519.PrivateKey + Metadata KeyMetadata +} + +func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return PrivateKey{}, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return PrivateKey{}, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + + // Unmarshal JSON-embedded format + var pk PrivateKey + if err := json.Unmarshal(b.Bytes, &pk); err != nil { + return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err) + } + + // Validate key length + if len(pk.Key) != ed25519.PrivateKeySize { + return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d", + ed25519.PrivateKeySize, len(pk.Key)) + } + + return pk, nil +} + +func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool { + // Verify with root keys + var rootKeys []ed25519.PublicKey + for _, r := range publicRootKeys { + rootKeys = append(rootKeys, r.Key) + } + + for _, k := range rootKeys { + if ed25519.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updatemanager/reposign/key_test.go new file mode 100644 index 000000000..f8e1676fb --- /dev/null +++ b/client/internal/updatemanager/reposign/key_test.go @@ -0,0 +1,636 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test KeyID functions + +func TestComputeKeyID(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + + // Verify it's the first 8 bytes of SHA-256 + h := sha256.Sum256(pub) + expectedID := KeyID{} + copy(expectedID[:], h[:8]) + + assert.Equal(t, expectedID, keyID) +} + +func TestComputeKeyID_Deterministic(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Computing KeyID multiple times should give the same result + keyID1 := computeKeyID(pub) + keyID2 := computeKeyID(pub) + + assert.Equal(t, keyID1, keyID2) +} + +func TestComputeKeyID_DifferentKeys(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID1 := computeKeyID(pub1) + keyID2 := computeKeyID(pub2) + + // Different keys should produce different IDs + assert.NotEqual(t, keyID1, keyID2) +} + +func TestParseKeyID_Valid(t *testing.T) { + hexStr := "0123456789abcdef" + + keyID, err := ParseKeyID(hexStr) + require.NoError(t, err) + + expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + assert.Equal(t, expected, keyID) +} + +func TestParseKeyID_InvalidLength(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"too short", "01234567"}, + {"too long", "0123456789abcdef00"}, + {"empty", ""}, + {"odd length", "0123456789abcde"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseKeyID(tt.input) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid KeyID length") + }) + } +} + +func TestParseKeyID_InvalidHex(t *testing.T) { + invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex + + _, err := ParseKeyID(invalidHex) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode KeyID") +} + +func TestKeyID_String(t *testing.T) { + keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + + str := keyID.String() + assert.Equal(t, "0123456789abcdef", str) +} + +func TestKeyID_RoundTrip(t *testing.T) { + original := "fedcba9876543210" + + keyID, err := ParseKeyID(original) + require.NoError(t, err) + + result := keyID.String() + assert.Equal(t, original, result) +} + +func TestKeyID_ZeroValue(t *testing.T) { + keyID := KeyID{} + str := keyID.String() + assert.Equal(t, "0000000000000000", str) +} + +// Test KeyMetadata + +func TestKeyMetadata_JSONMarshaling(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC), + } + + jsonData, err := json.Marshal(metadata) + require.NoError(t, err) + + var decoded KeyMetadata + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, metadata.ID, decoded.ID) + assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix()) + assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix()) +} + +func TestKeyMetadata_NoExpiration(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + ExpiresAt: time.Time{}, // Zero value = no expiration + } + + jsonData, err := json.Marshal(metadata) + require.NoError(t, err) + + var decoded KeyMetadata + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.True(t, decoded.ExpiresAt.IsZero()) +} + +// Test PublicKey + +func TestPublicKey_JSONMarshaling(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + var decoded PublicKey + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, pubKey.Key, decoded.Key) + assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID) +} + +// Test parsePublicKey + +func TestParsePublicKey_Valid(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + } + + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + + // Marshal to JSON + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + // Encode to PEM + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + // Parse it back + parsed, rest, err := parsePublicKey(pemData, tagRootPublic) + require.NoError(t, err) + assert.Empty(t, rest) + assert.Equal(t, pub, parsed.Key) + assert.Equal(t, metadata.ID, parsed.Metadata.ID) +} + +func TestParsePublicKey_InvalidPEM(t *testing.T) { + invalidPEM := []byte("not a PEM") + + _, _, err := parsePublicKey(invalidPEM, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM") +} + +func TestParsePublicKey_WrongType(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + // Encode with wrong type + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "WRONG TYPE", + Bytes: jsonData, + }) + + _, _, err = parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") +} + +func TestParsePublicKey_InvalidJSON(t *testing.T) { + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: []byte("invalid json"), + }) + + _, _, err := parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal") +} + +func TestParsePublicKey_InvalidKeySize(t *testing.T) { + // Create a public key with wrong size + pubKey := PublicKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + _, _, err = parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 public key size") +} + +func TestParsePublicKey_IDRecomputation(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create a public key with WRONG ID + wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: wrongID, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + // Parse should recompute the correct ID + parsed, _, err := parsePublicKey(pemData, tagRootPublic) + require.NoError(t, err) + + correctID := computeKeyID(pub) + assert.Equal(t, correctID, parsed.Metadata.ID) + assert.NotEqual(t, wrongID, parsed.Metadata.ID) +} + +// Test parsePublicKeyBundle + +func TestParsePublicKeyBundle_Single(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + keys, err := parsePublicKeyBundle(pemData, tagRootPublic) + require.NoError(t, err) + assert.Len(t, keys, 1) + assert.Equal(t, pub, keys[0].Key) +} + +func TestParsePublicKeyBundle_Multiple(t *testing.T) { + var bundle []byte + + // Create 3 keys + for i := 0; i < 3; i++ { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + bundle = append(bundle, pemData...) + } + + keys, err := parsePublicKeyBundle(bundle, tagRootPublic) + require.NoError(t, err) + assert.Len(t, keys, 3) +} + +func TestParsePublicKeyBundle_Empty(t *testing.T) { + _, err := parsePublicKeyBundle([]byte{}, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no keys found") +} + +func TestParsePublicKeyBundle_Invalid(t *testing.T) { + _, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic) + assert.Error(t, err) +} + +// Test PrivateKey + +func TestPrivateKey_JSONMarshaling(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + var decoded PrivateKey + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, privKey.Key, decoded.Key) + assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID) +} + +// Test parsePrivateKey + +func TestParsePrivateKey_Valid(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + parsed, err := parsePrivateKey(pemData, tagRootPrivate) + require.NoError(t, err) + assert.Equal(t, priv, parsed.Key) +} + +func TestParsePrivateKey_InvalidPEM(t *testing.T) { + _, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM") +} + +func TestParsePrivateKey_TrailingData(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + // Add trailing data + pemData = append(pemData, []byte("extra data")...) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "trailing PEM data") +} + +func TestParsePrivateKey_WrongType(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "WRONG TYPE", + Bytes: jsonData, + }) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") +} + +func TestParsePrivateKey_InvalidKeySize(t *testing.T) { + privKey := PrivateKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 private key size") +} + +// Test verifyAny + +func TestVerifyAny_ValidSignature(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv, message) + + rootKeys := []PublicKey{ + { + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + result := verifyAny(rootKeys, message, signature) + assert.True(t, result) +} + +func TestVerifyAny_InvalidSignature(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + invalidSignature := make([]byte, ed25519.SignatureSize) + + rootKeys := []PublicKey{ + { + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + result := verifyAny(rootKeys, message, invalidSignature) + assert.False(t, result) +} + +func TestVerifyAny_MultipleKeys(t *testing.T) { + // Create 3 key pairs + pub1, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub3, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv1, message) + + rootKeys := []PublicKey{ + {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}}, + {Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle + {Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}}, + } + + result := verifyAny(rootKeys, message, signature) + assert.True(t, result) +} + +func TestVerifyAny_NoMatchingKey(t *testing.T) { + _, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv1, message) + + // Only include pub2, not pub1 + rootKeys := []PublicKey{ + {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}}, + } + + result := verifyAny(rootKeys, message, signature) + assert.False(t, result) +} + +func TestVerifyAny_EmptyKeys(t *testing.T) { + message := []byte("test message") + signature := make([]byte, ed25519.SignatureSize) + + result := verifyAny([]PublicKey{}, message, signature) + assert.False(t, result) +} + +func TestVerifyAny_TamperedMessage(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv, message) + + rootKeys := []PublicKey{ + {Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}}, + } + + // Verify with different message + tamperedMessage := []byte("different message") + result := verifyAny(rootKeys, tamperedMessage, signature) + assert.False(t, result) +} diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updatemanager/reposign/revocation.go new file mode 100644 index 000000000..e679e212f --- /dev/null +++ b/client/internal/updatemanager/reposign/revocation.go @@ -0,0 +1,229 @@ +package reposign + +import ( + "crypto/ed25519" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour + defaultRevocationListExpiration = 365 * 24 * time.Hour +) + +type RevocationList struct { + Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time + LastUpdated time.Time `json:"last_updated"` // When the list was last modified + ExpiresAt time.Time `json:"expires_at"` // When the list expires +} + +func (rl RevocationList) MarshalJSON() ([]byte, error) { + // Convert map[KeyID]time.Time to map[string]time.Time + strMap := make(map[string]time.Time, len(rl.Revoked)) + for k, v := range rl.Revoked { + strMap[k.String()] = v + } + + return json.Marshal(map[string]interface{}{ + "revoked": strMap, + "last_updated": rl.LastUpdated, + "expires_at": rl.ExpiresAt, + }) +} + +func (rl *RevocationList) UnmarshalJSON(data []byte) error { + var temp struct { + Revoked map[string]time.Time `json:"revoked"` + LastUpdated time.Time `json:"last_updated"` + ExpiresAt time.Time `json:"expires_at"` + Version int `json:"version"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Convert map[string]time.Time back to map[KeyID]time.Time + rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked)) + for k, v := range temp.Revoked { + kid, err := ParseKeyID(k) + if err != nil { + return fmt.Errorf("failed to parse KeyID %q: %w", k, err) + } + rl.Revoked[kid] = v + } + + rl.LastUpdated = temp.LastUpdated + rl.ExpiresAt = temp.ExpiresAt + + return nil +} + +func ParseRevocationList(data []byte) (*RevocationList, error) { + var rl RevocationList + if err := json.Unmarshal(data, &rl); err != nil { + return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err) + } + + // Initialize the map if it's nil (in case of empty JSON object) + if rl.Revoked == nil { + rl.Revoked = make(map[KeyID]time.Time) + } + + if rl.LastUpdated.IsZero() { + return nil, fmt.Errorf("revocation list missing last_updated timestamp") + } + + if rl.ExpiresAt.IsZero() { + return nil, fmt.Errorf("revocation list missing expires_at timestamp") + } + + return &rl, nil +} + +func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) { + revoList, err := ParseRevocationList(data) + if err != nil { + log.Debugf("failed to parse revocation list: %s", err) + return nil, err + } + + now := time.Now().UTC() + + // Validate signature timestamp + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("revocation list signature error: %v", err) + return nil, err + } + + if now.Sub(signature.Timestamp) > maxRevocationSignatureAge { + err := fmt.Errorf("revocation list signature is too old: %v (created %v)", + now.Sub(signature.Timestamp), signature.Timestamp) + log.Debugf("revocation list signature error: %v", err) + return nil, err + } + + // Ensure LastUpdated is not in the future (with clock skew tolerance) + if revoList.LastUpdated.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated) + log.Errorf("rejecting future-dated revocation list: %v", err) + return nil, err + } + + // Check if the revocation list has expired + if now.After(revoList.ExpiresAt) { + err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now) + log.Errorf("rejecting expired revocation list: %v", err) + return nil, err + } + + // Ensure ExpiresAt is not in the future by more than the expected expiration window + // (allows some clock skew but prevents maliciously long expiration times) + if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) { + err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt) + log.Errorf("rejecting revocation list with invalid expiration: %v", err) + return nil, err + } + + // Validate signature timestamp is close to LastUpdated + // (prevents signing old lists with new timestamps) + timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs() + if timeDiff > maxClockSkew { + err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)", + signature.Timestamp, revoList.LastUpdated, timeDiff) + log.Errorf("timestamp mismatch in revocation list: %v", err) + return nil, err + } + + // Reconstruct the signed message: revocation_list_data || timestamp || version + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + if !verifyAny(publicRootKeys, msg, signature.Signature) { + return nil, errors.New("revocation list verification failed") + } + return revoList, nil +} + +func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) { + now := time.Now() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now.UTC(), + ExpiresAt: now.Add(expiration).UTC(), + } + + signature, err := signRevocationList(privateRootKey, rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err) + } + + rlData, err := json.Marshal(&rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err) + } + + signData, err := json.Marshal(signature) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal signature: %w", err) + } + + return rlData, signData, nil +} + +func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) { + now := time.Now().UTC() + + rl.Revoked[kid] = now + rl.LastUpdated = now + rl.ExpiresAt = now.Add(expiration) + + signature, err := signRevocationList(privateRootKey, rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err) + } + + rlData, err := json.Marshal(&rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err) + } + + signData, err := json.Marshal(signature) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal signature: %w", err) + } + + return rlData, signData, nil +} + +func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) { + data, err := json.Marshal(rl) + if err != nil { + return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err) + } + + timestamp := time.Now().UTC() + + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(privateRootKey.Key, msg) + + signature := &Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: privateRootKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + return signature, nil +} diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updatemanager/reposign/revocation_test.go new file mode 100644 index 000000000..d6d748f3d --- /dev/null +++ b/client/internal/updatemanager/reposign/revocation_test.go @@ -0,0 +1,860 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test RevocationList marshaling/unmarshaling + +func TestRevocationList_MarshalJSON(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC) + + rl := &RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID: revokedTime, + }, + LastUpdated: lastUpdated, + ExpiresAt: expiresAt, + } + + jsonData, err := json.Marshal(rl) + require.NoError(t, err) + + // Verify it can be unmarshaled back + var decoded map[string]interface{} + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Contains(t, decoded, "revoked") + assert.Contains(t, decoded, "last_updated") + assert.Contains(t, decoded, "expires_at") +} + +func TestRevocationList_UnmarshalJSON(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + + jsonData := map[string]interface{}{ + "revoked": map[string]string{ + keyID.String(): revokedTime.Format(time.RFC3339), + }, + "last_updated": lastUpdated.Format(time.RFC3339), + } + + jsonBytes, err := json.Marshal(jsonData) + require.NoError(t, err) + + var rl RevocationList + err = json.Unmarshal(jsonBytes, &rl) + require.NoError(t, err) + + assert.Len(t, rl.Revoked, 1) + assert.Contains(t, rl.Revoked, keyID) + assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix()) +} + +func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID1 := computeKeyID(pub1) + keyID2 := computeKeyID(pub2) + + original := &RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC), + }, + LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC), + } + + // Marshal + jsonData, err := original.MarshalJSON() + require.NoError(t, err) + + // Unmarshal + var decoded RevocationList + err = decoded.UnmarshalJSON(jsonData) + require.NoError(t, err) + + // Verify + assert.Len(t, decoded.Revoked, 2) + assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix()) + assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix()) + assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix()) +} + +func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) { + jsonData := []byte(`{ + "revoked": { + "invalid_key_id": "2024-01-15T10:30:00Z" + }, + "last_updated": "2024-01-15T11:00:00Z" + }`) + + var rl RevocationList + err := json.Unmarshal(jsonData, &rl) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse KeyID") +} + +func TestRevocationList_EmptyRevoked(t *testing.T) { + rl := &RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC(), + } + + jsonData, err := rl.MarshalJSON() + require.NoError(t, err) + + var decoded RevocationList + err = decoded.UnmarshalJSON(jsonData) + require.NoError(t, err) + + assert.Empty(t, decoded.Revoked) + assert.NotNil(t, decoded.Revoked) +} + +// Test ParseRevocationList + +func TestParseRevocationList_Valid(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + + rl := RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID: revokedTime, + }, + LastUpdated: lastUpdated, + ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC), + } + + jsonData, err := rl.MarshalJSON() + require.NoError(t, err) + + parsed, err := ParseRevocationList(jsonData) + require.NoError(t, err) + assert.NotNil(t, parsed) + assert.Len(t, parsed.Revoked, 1) + assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix()) +} + +func TestParseRevocationList_InvalidJSON(t *testing.T) { + invalidJSON := []byte("not valid json") + + _, err := ParseRevocationList(invalidJSON) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal") +} + +func TestParseRevocationList_MissingLastUpdated(t *testing.T) { + jsonData := []byte(`{ + "revoked": {} + }`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing last_updated") +} + +func TestParseRevocationList_EmptyObject(t *testing.T) { + jsonData := []byte(`{}`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing last_updated") +} + +func TestParseRevocationList_NilRevoked(t *testing.T) { + lastUpdated := time.Now().UTC() + expiresAt := lastUpdated.Add(90 * 24 * time.Hour) + jsonData := []byte(`{ + "last_updated": "` + lastUpdated.Format(time.RFC3339) + `", + "expires_at": "` + expiresAt.Format(time.RFC3339) + `" + }`) + + parsed, err := ParseRevocationList(jsonData) + require.NoError(t, err) + assert.NotNil(t, parsed.Revoked) + assert.Empty(t, parsed.Revoked) +} + +func TestParseRevocationList_MissingExpiresAt(t *testing.T) { + lastUpdated := time.Now().UTC() + jsonData := []byte(`{ + "revoked": {}, + "last_updated": "` + lastUpdated.Format(time.RFC3339) + `" + }`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing expires_at") +} + +// Test ValidateRevocationList + +func TestValidateRevocationList_Valid(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Validate + rl, err := ValidateRevocationList(rootKeys, rlData, *signature) + require.NoError(t, err) + assert.NotNil(t, rl) + assert.Empty(t, rl.Revoked) +} + +func TestValidateRevocationList_InvalidSignature(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Create invalid signature + invalidSig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + // Validate should fail + _, err = ValidateRevocationList(rootKeys, rlData, invalidSig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestValidateRevocationList_FutureTimestamp(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Modify timestamp to be in the future + signature.Timestamp = time.Now().UTC().Add(10 * time.Minute) + + _, err = ValidateRevocationList(rootKeys, rlData, *signature) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateRevocationList_TooOld(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Modify timestamp to be too old + signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour) + + _, err = ValidateRevocationList(rootKeys, rlData, *signature) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateRevocationList_InvalidJSON(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + signature := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + _, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature) + assert.Error(t, err) +} + +func TestValidateRevocationList_FutureLastUpdated(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with future LastUpdated + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC().Add(10 * time.Minute), + ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour), + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "LastUpdated is in the future") +} + +func TestValidateRevocationList_TimestampMismatch(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with LastUpdated far in the past + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC().Add(-1 * time.Hour), + ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour), + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it with current timestamp + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + // Modify signature timestamp to differ too much from LastUpdated + sig.Timestamp = time.Now().UTC() + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "differs too much") +} + +func TestValidateRevocationList_Expired(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list that expired in the past + now := time.Now().UTC() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now.Add(-100 * 24 * time.Hour), + ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + // Adjust signature timestamp to match LastUpdated + sig.Timestamp = rl.LastUpdated + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge) + now := time.Now().UTC() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now, + ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too far in the future") +} + +// Test CreateRevocationList + +func TestCreateRevocationList_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + assert.NotEmpty(t, rlData) + assert.NotEmpty(t, sigData) + + // Verify it can be parsed + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + assert.False(t, rl.LastUpdated.IsZero()) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +// Test ExtendRevocationList + +func TestExtendRevocationList_AddKey(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + + // Generate a key to revoke + revokedPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + revokedKeyID := computeKeyID(revokedPub) + + // Extend the revocation list + newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Verify the new list + newRL, err := ParseRevocationList(newRLData) + require.NoError(t, err) + assert.Len(t, newRL.Revoked, 1) + assert.Contains(t, newRL.Revoked, revokedKeyID) + + // Verify signature + sig, err := ParseSignature(newSigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestExtendRevocationList_MultipleKeys(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + + // Add first key + key1Pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + key1ID := computeKeyID(key1Pub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + + // Add second key + key2Pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + key2ID := computeKeyID(key2Pub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 2) + assert.Contains(t, rl.Revoked, key1ID) + assert.Contains(t, rl.Revoked, key2ID) +} + +func TestExtendRevocationList_DuplicateKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + + // Add a key + keyPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := computeKeyID(keyPub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + firstRevocationTime := rl.Revoked[keyID] + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Add the same key again + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + + // The revocation time should be updated + secondRevocationTime := rl.Revoked[keyID] + assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime)) +} + +func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + firstLastUpdated := rl.LastUpdated + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Extend list + keyPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := computeKeyID(keyPub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + + // LastUpdated should be updated + assert.True(t, rl.LastUpdated.After(firstLastUpdated)) +} + +// Integration test + +func TestRevocationList_FullWorkflow(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Step 1: Create empty revocation list + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 2: Validate it + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + rl, err := ValidateRevocationList(rootKeys, rlData, *sig) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + + // Step 3: Revoke a key + revokedPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + revokedKeyID := computeKeyID(revokedPub) + + rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 4: Validate the extended list + sig, err = ParseSignature(sigData) + require.NoError(t, err) + + rl, err = ValidateRevocationList(rootKeys, rlData, *sig) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + assert.Contains(t, rl.Revoked, revokedKeyID) + + // Step 5: Verify the revocation time is reasonable + revTime := rl.Revoked[revokedKeyID] + now := time.Now().UTC() + assert.True(t, revTime.Before(now) || revTime.Equal(now)) + assert.True(t, now.Sub(revTime) < time.Minute) +} diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updatemanager/reposign/root.go new file mode 100644 index 000000000..2c3ca54a0 --- /dev/null +++ b/client/internal/updatemanager/reposign/root.go @@ -0,0 +1,120 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "fmt" + "time" +) + +const ( + tagRootPrivate = "ROOT PRIVATE KEY" + tagRootPublic = "ROOT PUBLIC KEY" +) + +// RootKey is a root Key used to sign signing keys +type RootKey struct { + PrivateKey +} + +func (k RootKey) String() string { + return fmt.Sprintf( + "RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", + k.Metadata.ID, + k.Metadata.CreatedAt.Format(time.RFC3339), + k.Metadata.ExpiresAt.Format(time.RFC3339), + ) +} + +func ParseRootKey(privKeyPEM []byte) (*RootKey, error) { + pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root Key: %w", err) + } + return &RootKey{pk}, nil +} + +// ParseRootPublicKey parses a root public key from PEM format +func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) { + pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err) + } + return pk, nil +} + +// GenerateRootKey generates a new root Key pair with Metadata +func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) { + now := time.Now() + expirationTime := now.Add(expiration) + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, nil, err + } + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: now.UTC(), + ExpiresAt: expirationTime.UTC(), + } + + rk := &RootKey{ + PrivateKey{ + Key: priv, + Metadata: metadata, + }, + } + + // Marshal PrivateKey struct to JSON + privJSON, err := json.Marshal(rk.PrivateKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Marshal PublicKey struct to JSON + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + pubJSON, err := json.Marshal(pubKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM with metadata embedded in bytes + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON, + }) + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: pubJSON, + }) + + return rk, privPEM, pubPEM, nil +} + +func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) { + timestamp := time.Now().UTC() + + // This ensures the timestamp is cryptographically bound to the signature + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(rootKey.Key, msg) + // Create signature bundle with timestamp and Metadata + bundle := Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: rootKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + return json.Marshal(bundle) +} diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updatemanager/reposign/root_test.go new file mode 100644 index 000000000..e75e29729 --- /dev/null +++ b/client/internal/updatemanager/reposign/root_test.go @@ -0,0 +1,476 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test RootKey.String() + +func TestRootKey_String(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC) + + rk := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: expiresAt, + }, + }, + } + + str := rk.String() + assert.Contains(t, str, "RootKey") + assert.Contains(t, str, computeKeyID(pub).String()) + assert.Contains(t, str, "2024-01-15") + assert.Contains(t, str, "2034-01-15") +} + +func TestRootKey_String_NoExpiration(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + + rk := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: time.Time{}, // No expiration + }, + }, + } + + str := rk.String() + assert.Contains(t, str, "RootKey") + assert.Contains(t, str, "0001-01-01") // Zero time format +} + +// Test GenerateRootKey + +func TestGenerateRootKey_Valid(t *testing.T) { + expiration := 10 * 365 * 24 * time.Hour // 10 years + + rk, privPEM, pubPEM, err := GenerateRootKey(expiration) + require.NoError(t, err) + assert.NotNil(t, rk) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Verify the key has correct metadata + assert.False(t, rk.Metadata.CreatedAt.IsZero()) + assert.False(t, rk.Metadata.ExpiresAt.IsZero()) + assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt)) + + // Verify expiration is approximately correct + expectedExpiration := time.Now().Add(expiration) + timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration) + assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute) +} + +func TestGenerateRootKey_ShortExpiration(t *testing.T) { + expiration := 24 * time.Hour // 1 day + + rk, _, _, err := GenerateRootKey(expiration) + require.NoError(t, err) + assert.NotNil(t, rk) + + // Verify expiration + expectedExpiration := time.Now().Add(expiration) + timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration) + assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute) +} + +func TestGenerateRootKey_ZeroExpiration(t *testing.T) { + rk, _, _, err := GenerateRootKey(0) + require.NoError(t, err) + assert.NotNil(t, rk) + + // With zero expiration, ExpiresAt should be equal to CreatedAt + assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt) +} + +func TestGenerateRootKey_PEMFormat(t *testing.T) { + rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Verify private key PEM + privBlock, _ := pem.Decode(privPEM) + require.NotNil(t, privBlock) + assert.Equal(t, tagRootPrivate, privBlock.Type) + + var privKey PrivateKey + err = json.Unmarshal(privBlock.Bytes, &privKey) + require.NoError(t, err) + assert.Equal(t, rk.Key, privKey.Key) + + // Verify public key PEM + pubBlock, _ := pem.Decode(pubPEM) + require.NotNil(t, pubBlock) + assert.Equal(t, tagRootPublic, pubBlock.Type) + + var pubKey PublicKey + err = json.Unmarshal(pubBlock.Bytes, &pubKey) + require.NoError(t, err) + assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID) +} + +func TestGenerateRootKey_KeySize(t *testing.T) { + rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Ed25519 private key should be 64 bytes + assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key)) + + // Ed25519 public key should be 32 bytes + pubKey := rk.Key.Public().(ed25519.PublicKey) + assert.Equal(t, ed25519.PublicKeySize, len(pubKey)) +} + +func TestGenerateRootKey_UniqueKeys(t *testing.T) { + rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Different keys should have different IDs + assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID) + assert.NotEqual(t, rk1.Key, rk2.Key) +} + +// Test ParseRootKey + +func TestParseRootKey_Valid(t *testing.T) { + original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + parsed, err := ParseRootKey(privPEM) + require.NoError(t, err) + assert.NotNil(t, parsed) + + // Verify the parsed key matches the original + assert.Equal(t, original.Key, parsed.Key) + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) + assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix()) + assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix()) +} + +func TestParseRootKey_InvalidPEM(t *testing.T) { + _, err := ParseRootKey([]byte("not a valid PEM")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestParseRootKey_EmptyData(t *testing.T) { + _, err := ParseRootKey([]byte{}) + assert.Error(t, err) +} + +func TestParseRootKey_WrongType(t *testing.T) { + // Generate an artifact key instead of root key + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Try to parse artifact key as root key + _, err = ParseRootKey(privPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") + + // Just to use artifactKey to avoid unused variable warning + _ = artifactKey +} + +func TestParseRootKey_CorruptedJSON(t *testing.T) { + // Create PEM with corrupted JSON + corruptedPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: []byte("corrupted json data"), + }) + + _, err := ParseRootKey(corruptedPEM) + assert.Error(t, err) +} + +func TestParseRootKey_InvalidKeySize(t *testing.T) { + // Create a key with invalid size + invalidKey := PrivateKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + privJSON, err := json.Marshal(invalidKey) + require.NoError(t, err) + + invalidPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON, + }) + + _, err = ParseRootKey(invalidPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 private key size") +} + +func TestParseRootKey_Roundtrip(t *testing.T) { + // Generate a key + original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Parse it + parsed, err := ParseRootKey(privPEM) + require.NoError(t, err) + + // Generate PEM again from parsed key + privJSON2, err := json.Marshal(parsed.PrivateKey) + require.NoError(t, err) + + privPEM2 := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON2, + }) + + // Parse again + parsed2, err := ParseRootKey(privPEM2) + require.NoError(t, err) + + // Should still match original + assert.Equal(t, original.Key, parsed2.Key) + assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID) +} + +// Test SignArtifactKey + +func TestSignArtifactKey_Valid(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data := []byte("test data to sign") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Parse and verify signature + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, rootKey.Metadata.ID, sig.KeyID) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "sha512", sig.HashAlgo) + assert.False(t, sig.Timestamp.IsZero()) +} + +func TestSignArtifactKey_EmptyData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + sigData, err := SignArtifactKey(*rootKey, []byte{}) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Should still be able to parse + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestSignArtifactKey_Verify(t *testing.T) { + rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Parse public key + pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic) + require.NoError(t, err) + + // Sign some data + data := []byte("test data for verification") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + // Parse signature + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Reconstruct message + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix())) + + // Verify signature + valid := ed25519.Verify(pubKey.Key, msg, sig.Signature) + assert.True(t, valid) +} + +func TestSignArtifactKey_DifferentData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data1 := []byte("data1") + data2 := []byte("data2") + + sig1, err := SignArtifactKey(*rootKey, data1) + require.NoError(t, err) + + sig2, err := SignArtifactKey(*rootKey, data2) + require.NoError(t, err) + + // Different data should produce different signatures + assert.NotEqual(t, sig1, sig2) +} + +func TestSignArtifactKey_MultipleSignatures(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data := []byte("test data") + + // Sign twice with a small delay + sig1, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + sig2, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + // Signatures should be different due to different timestamps + assert.NotEqual(t, sig1, sig2) + + // Parse both signatures + parsed1, err := ParseSignature(sig1) + require.NoError(t, err) + + parsed2, err := ParseSignature(sig2) + require.NoError(t, err) + + // Timestamps should be different + assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp)) +} + +func TestSignArtifactKey_LargeData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Create 1MB of data + largeData := make([]byte, 1024*1024) + for i := range largeData { + largeData[i] = byte(i % 256) + } + + sigData, err := SignArtifactKey(*rootKey, largeData) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestSignArtifactKey_TimestampInSignature(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + beforeSign := time.Now().UTC() + data := []byte("test data") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + afterSign := time.Now().UTC() + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Timestamp should be between before and after + assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second))) + assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second))) +} + +// Integration test + +func TestRootKey_FullWorkflow(t *testing.T) { + // Step 1: Generate root key + rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + assert.NotNil(t, rootKey) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Step 2: Parse the private key back + parsedRootKey, err := ParseRootKey(privPEM) + require.NoError(t, err) + assert.Equal(t, rootKey.Key, parsedRootKey.Key) + assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID) + + // Step 3: Generate an artifact key using root key + artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, artifactKey) + + // Step 4: Verify the artifact key signature + pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic) + require.NoError(t, err) + + sig, err := ParseSignature(artifactSig) + require.NoError(t, err) + + artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic) + require.NoError(t, err) + + // Reconstruct message - SignArtifactKey signs the PEM, not the JSON + msg := make([]byte, 0, len(artifactPubPEM)+8) + msg = append(msg, artifactPubPEM...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix())) + + // Verify with root public key + valid := ed25519.Verify(pubKey.Key, msg, sig.Signature) + assert.True(t, valid, "Artifact key signature should be valid") + + // Step 5: Use artifact key to sign data + testData := []byte("This is test artifact data") + dataSig, err := SignData(*artifactKey, testData) + require.NoError(t, err) + assert.NotEmpty(t, dataSig) + + // Step 6: Verify the artifact data signature + dataSigParsed, err := ParseSignature(dataSig) + require.NoError(t, err) + + err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed) + assert.NoError(t, err, "Artifact data signature should be valid") +} + +func TestRootKey_ExpiredKeyWorkflow(t *testing.T) { + // Generate a root key that expires very soon + rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to generate artifact key with expired root key + _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updatemanager/reposign/signature.go new file mode 100644 index 000000000..c7f06e94e --- /dev/null +++ b/client/internal/updatemanager/reposign/signature.go @@ -0,0 +1,24 @@ +package reposign + +import ( + "encoding/json" + "time" +) + +// Signature contains a signature with associated Metadata +type Signature struct { + Signature []byte `json:"signature"` + Timestamp time.Time `json:"timestamp"` + KeyID KeyID `json:"key_id"` + Algorithm string `json:"algorithm"` // "ed25519" + HashAlgo string `json:"hash_algo"` // "blake2s" or sha512 +} + +func ParseSignature(data []byte) (*Signature, error) { + var signature Signature + if err := json.Unmarshal(data, &signature); err != nil { + return nil, err + } + + return &signature, nil +} diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updatemanager/reposign/signature_test.go new file mode 100644 index 000000000..1960c5518 --- /dev/null +++ b/client/internal/updatemanager/reposign/signature_test.go @@ -0,0 +1,277 @@ +package reposign + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseSignature_Valid(t *testing.T) { + timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + keyID, err := ParseKeyID("0123456789abcdef") + require.NoError(t, err) + + signatureData := []byte{0x01, 0x02, 0x03, 0x04} + + jsonData, err := json.Marshal(Signature{ + Signature: signatureData, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + }) + require.NoError(t, err) + + sig, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.Equal(t, signatureData, sig.Signature) + assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix()) + assert.Equal(t, keyID, sig.KeyID) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} + +func TestParseSignature_InvalidJSON(t *testing.T) { + invalidJSON := []byte(`{invalid json}`) + + sig, err := ParseSignature(invalidJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestParseSignature_EmptyData(t *testing.T) { + emptyJSON := []byte(`{}`) + + sig, err := ParseSignature(emptyJSON) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.Empty(t, sig.Signature) + assert.True(t, sig.Timestamp.IsZero()) + assert.Equal(t, KeyID{}, sig.KeyID) + assert.Empty(t, sig.Algorithm) + assert.Empty(t, sig.HashAlgo) +} + +func TestParseSignature_MissingFields(t *testing.T) { + // JSON with only some fields + partialJSON := []byte(`{ + "signature": "AQIDBA==", + "algorithm": "ed25519" + }`) + + sig, err := ParseSignature(partialJSON) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.True(t, sig.Timestamp.IsZero()) +} + +func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) { + timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC) + keyID, err := ParseKeyID("fedcba9876543210") + require.NoError(t, err) + + original := Signature{ + Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + // Marshal + jsonData, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + + // Verify + assert.Equal(t, original.Signature, parsed.Signature) + assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix()) + assert.Equal(t, original.KeyID, parsed.KeyID) + assert.Equal(t, original.Algorithm, parsed.Algorithm) + assert.Equal(t, original.HashAlgo, parsed.HashAlgo) +} + +func TestSignature_NilSignatureBytes(t *testing.T) { + timestamp := time.Now().UTC() + keyID, err := ParseKeyID("0011223344556677") + require.NoError(t, err) + + sig := Signature{ + Signature: nil, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Nil(t, parsed.Signature) +} + +func TestSignature_LargeSignature(t *testing.T) { + timestamp := time.Now().UTC() + keyID, err := ParseKeyID("aabbccddeeff0011") + require.NoError(t, err) + + // Create a large signature (64 bytes for ed25519) + largeSignature := make([]byte, 64) + for i := range largeSignature { + largeSignature[i] = byte(i) + } + + sig := Signature{ + Signature: largeSignature, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, largeSignature, parsed.Signature) +} + +func TestSignature_WithDifferentHashAlgorithms(t *testing.T) { + tests := []struct { + name string + hashAlgo string + }{ + {"blake2s", "blake2s"}, + {"sha512", "sha512"}, + {"sha256", "sha256"}, + {"empty", ""}, + } + + keyID, err := ParseKeyID("1122334455667788") + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sig := Signature{ + Signature: []byte{0x01, 0x02}, + Timestamp: time.Now().UTC(), + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: tt.hashAlgo, + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, tt.hashAlgo, parsed.HashAlgo) + }) + } +} + +func TestSignature_TimestampPrecision(t *testing.T) { + // Test that timestamp preserves precision through JSON marshaling + timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC) + keyID, err := ParseKeyID("8877665544332211") + require.NoError(t, err) + + sig := Signature{ + Signature: []byte{0xaa, 0xbb}, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + + // JSON timestamps typically have second or millisecond precision + // so we check that at least seconds match + assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix()) +} + +func TestParseSignature_MalformedKeyID(t *testing.T) { + // Test with a malformed KeyID field + malformedJSON := []byte(`{ + "signature": "AQID", + "timestamp": "2024-01-15T10:30:00Z", + "key_id": "invalid_keyid_format", + "algorithm": "ed25519", + "hash_algo": "blake2s" + }`) + + // This should fail since "invalid_keyid_format" is not a valid KeyID + sig, err := ParseSignature(malformedJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestParseSignature_InvalidTimestamp(t *testing.T) { + // Test with an invalid timestamp format + invalidTimestampJSON := []byte(`{ + "signature": "AQID", + "timestamp": "not-a-timestamp", + "key_id": "0123456789abcdef", + "algorithm": "ed25519", + "hash_algo": "blake2s" + }`) + + sig, err := ParseSignature(invalidTimestampJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestSignature_ZeroKeyID(t *testing.T) { + // Test with a zero KeyID + sig := Signature{ + Signature: []byte{0x01, 0x02, 0x03}, + Timestamp: time.Now().UTC(), + KeyID: KeyID{}, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, KeyID{}, parsed.KeyID) +} + +func TestParseSignature_ExtraFields(t *testing.T) { + // JSON with extra fields that should be ignored + jsonWithExtra := []byte(`{ + "signature": "AQIDBA==", + "timestamp": "2024-01-15T10:30:00Z", + "key_id": "0123456789abcdef", + "algorithm": "ed25519", + "hash_algo": "blake2s", + "extra_field": "should be ignored", + "another_extra": 12345 + }`) + + sig, err := ParseSignature(jsonWithExtra) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updatemanager/reposign/verify.go new file mode 100644 index 000000000..0af2a8c9e --- /dev/null +++ b/client/internal/updatemanager/reposign/verify.go @@ -0,0 +1,187 @@ +package reposign + +import ( + "context" + "fmt" + "net/url" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" +) + +const ( + artifactPubKeysFileName = "artifact-key-pub.pem" + artifactPubKeysSigFileName = "artifact-key-pub.pem.sig" + revocationFileName = "revocation-list.json" + revocationSignFileName = "revocation-list.json.sig" + + keySizeLimit = 5 * 1024 * 1024 //5MB + signatureLimit = 1024 + revocationLimit = 10 * 1024 * 1024 +) + +type ArtifactVerify struct { + rootKeys []PublicKey + keysBaseURL *url.URL + + revocationList *RevocationList +} + +func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) { + allKeys, err := loadEmbeddedPublicKeys() + if err != nil { + return nil, err + } + + return newArtifactVerify(keysBaseURL, allKeys) +} + +func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) { + ku, err := url.Parse(keysBaseURL) + if err != nil { + return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err) + } + + a := &ArtifactVerify{ + rootKeys: allKeys, + keysBaseURL: ku, + } + return a, nil +} + +func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error { + version = strings.TrimPrefix(version, "v") + + revocationList, err := a.loadRevocationList(ctx) + if err != nil { + return fmt.Errorf("failed to load revocation list: %v", err) + } + a.revocationList = revocationList + + artifactPubKeys, err := a.loadArtifactKeys(ctx) + if err != nil { + return fmt.Errorf("failed to load artifact keys: %v", err) + } + + signature, err := a.loadArtifactSignature(ctx, version, artifactFile) + if err != nil { + return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err) + } + + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + log.Errorf("failed to read artifact file: %v", err) + return fmt.Errorf("failed to read artifact file: %w", err) + } + + if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil { + return fmt.Errorf("failed to validate artifact: %v", err) + } + + return nil +} + +func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) { + downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String() + data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit) + if err != nil { + log.Debugf("failed to download revocation list '%s': %s", downloadURL, err) + return nil, err + } + + downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String() + sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download revocation list '%s': %s", downloadURL, err) + return nil, err + } + + signature, err := ParseSignature(sigData) + if err != nil { + log.Debugf("failed to parse revocation list signature: %s", err) + return nil, err + } + + return ValidateRevocationList(a.rootKeys, data, *signature) +} + +func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) { + downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String() + log.Debugf("starting downloading artifact keys from: %s", downloadURL) + data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit) + if err != nil { + log.Debugf("failed to download artifact keys: %s", err) + return nil, err + } + + downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String() + log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL) + sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download signature of public keys: %s", err) + return nil, err + } + + signature, err := ParseSignature(sigData) + if err != nil { + log.Debugf("failed to parse signature of public keys: %s", err) + return nil, err + } + + return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList) +} + +func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) { + artifactFile = filepath.Base(artifactFile) + downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String() + data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download artifact signature: %s", err) + return nil, err + } + + signature, err := ParseSignature(data) + if err != nil { + log.Debugf("failed to parse artifact signature: %s", err) + return nil, err + } + + return signature, nil + +} + +func loadEmbeddedPublicKeys() ([]PublicKey, error) { + files, err := embeddedCerts.ReadDir(embeddedCertsDir) + if err != nil { + return nil, fmt.Errorf("failed to read embedded certs: %w", err) + } + + var allKeys []PublicKey + for _, file := range files { + if file.IsDir() { + continue + } + + data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name()) + if err != nil { + return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err) + } + + keys, err := parsePublicKeyBundle(data, tagRootPublic) + if err != nil { + return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err) + } + + allKeys = append(allKeys, keys...) + } + + if len(allKeys) == 0 { + return nil, fmt.Errorf("no valid public keys found in embedded certs") + } + + return allKeys, nil +} diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updatemanager/reposign/verify_test.go new file mode 100644 index 000000000..c29393bad --- /dev/null +++ b/client/internal/updatemanager/reposign/verify_test.go @@ -0,0 +1,528 @@ +package reposign + +import ( + "context" + "crypto/ed25519" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test ArtifactVerify construction + +func TestArtifactVerify_Construction(t *testing.T) { + // Generate test root key + rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic) + require.NoError(t, err) + + keysBaseURL := "http://localhost:8080/artifact-signatures" + + av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey}) + require.NoError(t, err) + + assert.NotNil(t, av) + assert.NotEmpty(t, av.rootKeys) + assert.Equal(t, keysBaseURL, av.keysBaseURL.String()) + + // Verify root key structure + assert.NotEmpty(t, av.rootKeys[0].Key) + assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID) + assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero()) +} + +func TestArtifactVerify_MultipleRootKeys(t *testing.T) { + // Generate multiple test root keys + rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic) + require.NoError(t, err) + + rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic) + require.NoError(t, err) + + keysBaseURL := "http://localhost:8080/artifact-signatures" + + av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2}) + assert.NoError(t, err) + assert.Len(t, av.rootKeys, 2) + assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID) +} + +// Test Verify workflow with mock HTTP server + +func TestArtifactVerify_FullWorkflow(t *testing.T) { + // Create temporary test directory + tempDir := t.TempDir() + + // Step 1: Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Step 2: Generate artifact key + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + // Step 3: Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 4: Bundle artifact keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Step 5: Create test artifact + artifactPath := filepath.Join(tempDir, "test-artifact.bin") + artifactData := []byte("This is test artifact data for verification") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + // Step 6: Sign artifact + artifactSigData, err := SignData(*artifactKey, artifactData) + require.NoError(t, err) + + // Step 7: Setup mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifacts/v1.0.0/test-artifact.bin": + _, _ = w.Write(artifactData) + case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + // Step 8: Create ArtifactVerify with test root key + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + // Step 9: Verify artifact + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.NoError(t, err) +} + +func TestArtifactVerify_InvalidRevocationList(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + // Setup server with invalid revocation list + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write([]byte("invalid data")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load revocation list") +} + +func TestArtifactVerify_MissingArtifactFile(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + // Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Create signature for non-existent file + testData := []byte("test") + artifactSigData, err := SignData(*artifactKey, testData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/missing.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", "file.bin") + assert.Error(t, err) +} + +func TestArtifactVerify_ServerUnavailable(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + // Use URL that doesn't exist + av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) +} + +func TestArtifactVerify_ContextCancellation(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + // Create a server that delays response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey}) + require.NoError(t, err) + + // Create context that cancels quickly + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) +} + +func TestArtifactVerify_WithRevocation(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Generate two artifact keys + artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1) + require.NoError(t, err) + + _, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2) + require.NoError(t, err) + + // Create revocation list with first key revoked + emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + parsedRevocation, err := ParseRevocationList(emptyRevocation) + require.NoError(t, err) + + revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle both keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2}) + require.NoError(t, err) + + // Create artifact signed by revoked key + artifactPath := filepath.Join(tempDir, "test.bin") + artifactData := []byte("test data") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + artifactSigData, err := SignData(*artifactKey1, artifactData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should fail because the signing key is revoked + assert.Error(t, err) + assert.Contains(t, err.Error(), "no signing Key found") +} + +func TestArtifactVerify_ValidWithSecondKey(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Generate two artifact keys + _, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1) + require.NoError(t, err) + + artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2) + require.NoError(t, err) + + // Create revocation list with first key revoked + emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + parsedRevocation, err := ParseRevocationList(emptyRevocation) + require.NoError(t, err) + + revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle both keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2}) + require.NoError(t, err) + + // Create artifact signed by second key (not revoked) + artifactPath := filepath.Join(tempDir, "test.bin") + artifactData := []byte("test data") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + artifactSigData, err := SignData(*artifactKey2, artifactData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should succeed because second key is not revoked + assert.NoError(t, err) +} + +func TestArtifactVerify_TamperedArtifact(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key and artifact key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + // Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Sign original data + originalData := []byte("original data") + artifactSigData, err := SignData(*artifactKey, originalData) + require.NoError(t, err) + + // Write tampered data to file + artifactPath := filepath.Join(tempDir, "test.bin") + tamperedData := []byte("tampered data") + err = os.WriteFile(artifactPath, tamperedData, 0644) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should fail because artifact was tampered + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to validate artifact") +} + +// Test URL validation + +func TestArtifactVerify_URLParsing(t *testing.T) { + tests := []struct { + name string + keysBaseURL string + expectError bool + }{ + { + name: "Valid HTTP URL", + keysBaseURL: "http://example.com/artifact-signatures", + expectError: false, + }, + { + name: "Valid HTTPS URL", + keysBaseURL: "https://example.com/artifact-signatures", + expectError: false, + }, + { + name: "URL with port", + keysBaseURL: "http://localhost:8080/artifact-signatures", + expectError: false, + }, + { + name: "Invalid URL", + keysBaseURL: "://invalid", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := newArtifactVerify(tt.keysBaseURL, nil) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/client/internal/updatemanager/update.go b/client/internal/updatemanager/update.go new file mode 100644 index 000000000..875b50b49 --- /dev/null +++ b/client/internal/updatemanager/update.go @@ -0,0 +1,11 @@ +package updatemanager + +import v "github.com/hashicorp/go-version" + +type UpdateInterface interface { + StopWatch() + SetDaemonVersion(newVersion string) bool + SetOnUpdateListener(updateFn func()) + LatestVersion() *v.Version + StartFetcher() +} diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 2109d4b15..f3458ccea 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -1,9 +1,12 @@ +//go:build ios + package NetBirdSDK import ( "context" "fmt" "net/netip" + "os" "sort" "strings" "sync" @@ -20,8 +23,8 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile @@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s } // Run start the internal client. It is a blocker function -func (c *Client) Run(fd int32, interfaceName string) error { +func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { + exportEnvList(envList) log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ @@ -127,7 +131,7 @@ func (c *Client) Run(fd int32, interfaceName string) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } @@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "") if err != nil { return err.Error() } @@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID { } return netIDs } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k)) + log.Debugf("Setting env variable %s: %s", k, v) + + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } else { + log.Debugf("Env variable %s was set successfully", k) + } + } +} diff --git a/client/ios/NetBirdSDK/env_list.go b/client/ios/NetBirdSDK/env_list.go new file mode 100644 index 000000000..4800803d7 --- /dev/null +++ b/client/ios/NetBirdSDK/env_list.go @@ -0,0 +1,34 @@ +//go:build ios + +package NetBirdSDK + +import "github.com/netbirdio/netbird/client/internal/peer" + +// EnvList is an exported struct to be bound by gomobile +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair +func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} + +// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client +func GetEnvKeyNBForceRelay() string { + return peer.EnvKeyNBForceRelay +} diff --git a/client/ios/NetBirdSDK/gomobile.go b/client/ios/NetBirdSDK/gomobile.go index 9eadd6a7f..79bf0c2ac 100644 --- a/client/ios/NetBirdSDK/gomobile.go +++ b/client/ios/NetBirdSDK/gomobile.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import _ "golang.org/x/mobile/bind" diff --git a/client/ios/NetBirdSDK/logger.go b/client/ios/NetBirdSDK/logger.go index f1ad1b9f6..531d0ba89 100644 --- a/client/ios/NetBirdSDK/logger.go +++ b/client/ios/NetBirdSDK/logger.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 570c44f80..1c2b38a61 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go index 16c5039eb..9b00568be 100644 --- a/client/ios/NetBirdSDK/peer_notifier.go +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK // PeerInfo describe information about the peers. It designed for the UI usage diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index 5e7050465..39ae06538 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index 780443a7b..5f75e7c9a 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go index 30d0d0d0a..7b84d6e1c 100644 --- a/client/ios/NetBirdSDK/routes.go +++ b/client/ios/NetBirdSDK/routes.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK // RoutesSelectionInfoCollection made for Java layer to get non default types as collection diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. +func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/client/netbird.wxs b/client/netbird.wxs index ba827debf..03221dd91 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -51,7 +51,7 @@ - + diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 34230a5b4..80e5bb9c5 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,21 +1,20 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v6.33.1 // source: daemon.proto package proto import ( - reflect "reflect" - sync "sync" - unsafe "unsafe" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" _ "google.golang.org/protobuf/types/descriptorpb" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" ) const ( @@ -89,6 +88,56 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +// avoid collision with loglevel enum +type OSLifecycleRequest_CycleType int32 + +const ( + OSLifecycleRequest_UNKNOWN OSLifecycleRequest_CycleType = 0 + OSLifecycleRequest_SLEEP OSLifecycleRequest_CycleType = 1 + OSLifecycleRequest_WAKEUP OSLifecycleRequest_CycleType = 2 +) + +// Enum value maps for OSLifecycleRequest_CycleType. +var ( + OSLifecycleRequest_CycleType_name = map[int32]string{ + 0: "UNKNOWN", + 1: "SLEEP", + 2: "WAKEUP", + } + OSLifecycleRequest_CycleType_value = map[string]int32{ + "UNKNOWN": 0, + "SLEEP": 1, + "WAKEUP": 2, + } +) + +func (x OSLifecycleRequest_CycleType) Enum() *OSLifecycleRequest_CycleType { + p := new(OSLifecycleRequest_CycleType) + *p = x + return p +} + +func (x OSLifecycleRequest_CycleType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { + return file_daemon_proto_enumTypes[1].Descriptor() +} + +func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { + return &file_daemon_proto_enumTypes[1] +} + +func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use OSLifecycleRequest_CycleType.Descriptor instead. +func (OSLifecycleRequest_CycleType) EnumDescriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1, 0} +} + type SystemEvent_Severity int32 const ( @@ -125,11 +174,11 @@ func (x SystemEvent_Severity) String() string { } func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[1].Descriptor() + return file_daemon_proto_enumTypes[2].Descriptor() } func (SystemEvent_Severity) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[1] + return &file_daemon_proto_enumTypes[2] } func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { @@ -138,7 +187,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49, 0} + return file_daemon_proto_rawDescGZIP(), []int{53, 0} } type SystemEvent_Category int32 @@ -180,11 +229,11 @@ func (x SystemEvent_Category) String() string { } func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[2].Descriptor() + return file_daemon_proto_enumTypes[3].Descriptor() } func (SystemEvent_Category) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[2] + return &file_daemon_proto_enumTypes[3] } func (x SystemEvent_Category) Number() protoreflect.EnumNumber { @@ -193,7 +242,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49, 1} + return file_daemon_proto_rawDescGZIP(), []int{53, 1} } type EmptyRequest struct { @@ -232,6 +281,86 @@ func (*EmptyRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +type OSLifecycleRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Type OSLifecycleRequest_CycleType `protobuf:"varint,1,opt,name=type,proto3,enum=daemon.OSLifecycleRequest_CycleType" json:"type,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OSLifecycleRequest) Reset() { + *x = OSLifecycleRequest{} + mi := &file_daemon_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OSLifecycleRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OSLifecycleRequest) ProtoMessage() {} + +func (x *OSLifecycleRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OSLifecycleRequest.ProtoReflect.Descriptor instead. +func (*OSLifecycleRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1} +} + +func (x *OSLifecycleRequest) GetType() OSLifecycleRequest_CycleType { + if x != nil { + return x.Type + } + return OSLifecycleRequest_UNKNOWN +} + +type OSLifecycleResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OSLifecycleResponse) Reset() { + *x = OSLifecycleResponse{} + mi := &file_daemon_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OSLifecycleResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OSLifecycleResponse) ProtoMessage() {} + +func (x *OSLifecycleResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OSLifecycleResponse.ProtoReflect.Descriptor instead. +func (*OSLifecycleResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{2} +} + type LoginRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // setupKey netbird setup key. @@ -280,13 +409,21 @@ type LoginRequest struct { ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // hint is used to pre-fill the email/username field during SSO authentication + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + EnableSSHRoot *bool `protobuf:"varint,34,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP *bool `protobuf:"varint,35,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { *x = LoginRequest{} - mi := &file_daemon_proto_msgTypes[1] + mi := &file_daemon_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -298,7 +435,7 @@ func (x *LoginRequest) String() string { func (*LoginRequest) ProtoMessage() {} func (x *LoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[1] + mi := &file_daemon_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -311,7 +448,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. func (*LoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1} + return file_daemon_proto_rawDescGZIP(), []int{3} } func (x *LoginRequest) GetSetupKey() string { @@ -539,6 +676,55 @@ func (x *LoginRequest) GetMtu() int64 { return 0 } +func (x *LoginRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + +func (x *LoginRequest) GetEnableSSHRoot() bool { + if x != nil && x.EnableSSHRoot != nil { + return *x.EnableSSHRoot + } + return false +} + +func (x *LoginRequest) GetEnableSSHSFTP() bool { + if x != nil && x.EnableSSHSFTP != nil { + return *x.EnableSSHSFTP + } + return false +} + +func (x *LoginRequest) GetEnableSSHLocalPortForwarding() bool { + if x != nil && x.EnableSSHLocalPortForwarding != nil { + return *x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool { + if x != nil && x.EnableSSHRemotePortForwarding != nil { + return *x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *LoginRequest) GetDisableSSHAuth() bool { + if x != nil && x.DisableSSHAuth != nil { + return *x.DisableSSHAuth + } + return false +} + +func (x *LoginRequest) GetSshJWTCacheTTL() int32 { + if x != nil && x.SshJWTCacheTTL != nil { + return *x.SshJWTCacheTTL + } + return 0 +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -551,7 +737,7 @@ type LoginResponse struct { func (x *LoginResponse) Reset() { *x = LoginResponse{} - mi := &file_daemon_proto_msgTypes[2] + mi := &file_daemon_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -563,7 +749,7 @@ func (x *LoginResponse) String() string { func (*LoginResponse) ProtoMessage() {} func (x *LoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[2] + mi := &file_daemon_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -576,7 +762,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. func (*LoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{2} + return file_daemon_proto_rawDescGZIP(), []int{4} } func (x *LoginResponse) GetNeedsSSOLogin() bool { @@ -617,7 +803,7 @@ type WaitSSOLoginRequest struct { func (x *WaitSSOLoginRequest) Reset() { *x = WaitSSOLoginRequest{} - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -629,7 +815,7 @@ func (x *WaitSSOLoginRequest) String() string { func (*WaitSSOLoginRequest) ProtoMessage() {} func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -642,7 +828,7 @@ func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginRequest.ProtoReflect.Descriptor instead. func (*WaitSSOLoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{3} + return file_daemon_proto_rawDescGZIP(), []int{5} } func (x *WaitSSOLoginRequest) GetUserCode() string { @@ -668,7 +854,7 @@ type WaitSSOLoginResponse struct { func (x *WaitSSOLoginResponse) Reset() { *x = WaitSSOLoginResponse{} - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -680,7 +866,7 @@ func (x *WaitSSOLoginResponse) String() string { func (*WaitSSOLoginResponse) ProtoMessage() {} func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -693,7 +879,7 @@ func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginResponse.ProtoReflect.Descriptor instead. func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{4} + return file_daemon_proto_rawDescGZIP(), []int{6} } func (x *WaitSSOLoginResponse) GetEmail() string { @@ -707,13 +893,14 @@ type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` + AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *UpRequest) Reset() { *x = UpRequest{} - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -725,7 +912,7 @@ func (x *UpRequest) String() string { func (*UpRequest) ProtoMessage() {} func (x *UpRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -738,7 +925,7 @@ func (x *UpRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use UpRequest.ProtoReflect.Descriptor instead. func (*UpRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{5} + return file_daemon_proto_rawDescGZIP(), []int{7} } func (x *UpRequest) GetProfileName() string { @@ -755,6 +942,13 @@ func (x *UpRequest) GetUsername() string { return "" } +func (x *UpRequest) GetAutoUpdate() bool { + if x != nil && x.AutoUpdate != nil { + return *x.AutoUpdate + } + return false +} + type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -763,7 +957,7 @@ type UpResponse struct { func (x *UpResponse) Reset() { *x = UpResponse{} - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -775,7 +969,7 @@ func (x *UpResponse) String() string { func (*UpResponse) ProtoMessage() {} func (x *UpResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -788,7 +982,7 @@ func (x *UpResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use UpResponse.ProtoReflect.Descriptor instead. func (*UpResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{6} + return file_daemon_proto_rawDescGZIP(), []int{8} } type StatusRequest struct { @@ -803,7 +997,7 @@ type StatusRequest struct { func (x *StatusRequest) Reset() { *x = StatusRequest{} - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -815,7 +1009,7 @@ func (x *StatusRequest) String() string { func (*StatusRequest) ProtoMessage() {} func (x *StatusRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -828,7 +1022,7 @@ func (x *StatusRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead. func (*StatusRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{7} + return file_daemon_proto_rawDescGZIP(), []int{9} } func (x *StatusRequest) GetGetFullPeerStatus() bool { @@ -865,7 +1059,7 @@ type StatusResponse struct { func (x *StatusResponse) Reset() { *x = StatusResponse{} - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -877,7 +1071,7 @@ func (x *StatusResponse) String() string { func (*StatusResponse) ProtoMessage() {} func (x *StatusResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -890,7 +1084,7 @@ func (x *StatusResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. func (*StatusResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{8} + return file_daemon_proto_rawDescGZIP(), []int{10} } func (x *StatusResponse) GetStatus() string { @@ -922,7 +1116,7 @@ type DownRequest struct { func (x *DownRequest) Reset() { *x = DownRequest{} - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -934,7 +1128,7 @@ func (x *DownRequest) String() string { func (*DownRequest) ProtoMessage() {} func (x *DownRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -947,7 +1141,7 @@ func (x *DownRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DownRequest.ProtoReflect.Descriptor instead. func (*DownRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{9} + return file_daemon_proto_rawDescGZIP(), []int{11} } type DownResponse struct { @@ -958,7 +1152,7 @@ type DownResponse struct { func (x *DownResponse) Reset() { *x = DownResponse{} - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -970,7 +1164,7 @@ func (x *DownResponse) String() string { func (*DownResponse) ProtoMessage() {} func (x *DownResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -983,7 +1177,7 @@ func (x *DownResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DownResponse.ProtoReflect.Descriptor instead. func (*DownResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{10} + return file_daemon_proto_rawDescGZIP(), []int{12} } type GetConfigRequest struct { @@ -996,7 +1190,7 @@ type GetConfigRequest struct { func (x *GetConfigRequest) Reset() { *x = GetConfigRequest{} - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1008,7 +1202,7 @@ func (x *GetConfigRequest) String() string { func (*GetConfigRequest) ProtoMessage() {} func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1021,7 +1215,7 @@ func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigRequest.ProtoReflect.Descriptor instead. func (*GetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{11} + return file_daemon_proto_rawDescGZIP(), []int{13} } func (x *GetConfigRequest) GetProfileName() string { @@ -1049,30 +1243,36 @@ type GetConfigResponse struct { // preSharedKey settings value. PreSharedKey string `protobuf:"bytes,4,opt,name=preSharedKey,proto3" json:"preSharedKey,omitempty"` // adminURL settings value. - AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` - InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` - WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` - Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` - DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` - RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"` - LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` - BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` - NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"` - DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"` - DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"` - DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"` - BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"` - DisableFirewall bool `protobuf:"varint,21,opt,name=disable_firewall,json=disableFirewall,proto3" json:"disable_firewall,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` + InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` + WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` + Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` + DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"` + LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` + NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"` + DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"` + DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"` + DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"` + BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"` + EnableSSHRoot bool `protobuf:"varint,21,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"` + DisableFirewall bool `protobuf:"varint,27,opt,name=disable_firewall,json=disableFirewall,proto3" json:"disable_firewall,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetConfigResponse) Reset() { *x = GetConfigResponse{} - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1084,7 +1284,7 @@ func (x *GetConfigResponse) String() string { func (*GetConfigResponse) ProtoMessage() {} func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1097,7 +1297,7 @@ func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigResponse.ProtoReflect.Descriptor instead. func (*GetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{12} + return file_daemon_proto_rawDescGZIP(), []int{14} } func (x *GetConfigResponse) GetManagementUrl() string { @@ -1240,6 +1440,48 @@ func (x *GetConfigResponse) GetBlockLanAccess() bool { return false } +func (x *GetConfigResponse) GetEnableSSHRoot() bool { + if x != nil { + return x.EnableSSHRoot + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHSFTP() bool { + if x != nil { + return x.EnableSSHSFTP + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHLocalPortForwarding() bool { + if x != nil { + return x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool { + if x != nil { + return x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *GetConfigResponse) GetDisableSSHAuth() bool { + if x != nil { + return x.DisableSSHAuth + } + return false +} + +func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 { + if x != nil { + return x.SshJWTCacheTTL + } + return 0 +} + func (x *GetConfigResponse) GetDisableFirewall() bool { if x != nil { return x.DisableFirewall @@ -1267,13 +1509,14 @@ type PeerState struct { Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` + SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *PeerState) Reset() { *x = PeerState{} - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1285,7 +1528,7 @@ func (x *PeerState) String() string { func (*PeerState) ProtoMessage() {} func (x *PeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1298,7 +1541,7 @@ func (x *PeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerState.ProtoReflect.Descriptor instead. func (*PeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{13} + return file_daemon_proto_rawDescGZIP(), []int{15} } func (x *PeerState) GetIP() string { @@ -1420,6 +1663,13 @@ func (x *PeerState) GetRelayAddress() string { return "" } +func (x *PeerState) GetSshHostKey() []byte { + if x != nil { + return x.SshHostKey + } + return nil +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1436,7 +1686,7 @@ type LocalPeerState struct { func (x *LocalPeerState) Reset() { *x = LocalPeerState{} - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1448,7 +1698,7 @@ func (x *LocalPeerState) String() string { func (*LocalPeerState) ProtoMessage() {} func (x *LocalPeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1461,7 +1711,7 @@ func (x *LocalPeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use LocalPeerState.ProtoReflect.Descriptor instead. func (*LocalPeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{14} + return file_daemon_proto_rawDescGZIP(), []int{16} } func (x *LocalPeerState) GetIP() string { @@ -1525,7 +1775,7 @@ type SignalState struct { func (x *SignalState) Reset() { *x = SignalState{} - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1537,7 +1787,7 @@ func (x *SignalState) String() string { func (*SignalState) ProtoMessage() {} func (x *SignalState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1550,7 +1800,7 @@ func (x *SignalState) ProtoReflect() protoreflect.Message { // Deprecated: Use SignalState.ProtoReflect.Descriptor instead. func (*SignalState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{15} + return file_daemon_proto_rawDescGZIP(), []int{17} } func (x *SignalState) GetURL() string { @@ -1586,7 +1836,7 @@ type ManagementState struct { func (x *ManagementState) Reset() { *x = ManagementState{} - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1598,7 +1848,7 @@ func (x *ManagementState) String() string { func (*ManagementState) ProtoMessage() {} func (x *ManagementState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1611,7 +1861,7 @@ func (x *ManagementState) ProtoReflect() protoreflect.Message { // Deprecated: Use ManagementState.ProtoReflect.Descriptor instead. func (*ManagementState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{16} + return file_daemon_proto_rawDescGZIP(), []int{18} } func (x *ManagementState) GetURL() string { @@ -1647,7 +1897,7 @@ type RelayState struct { func (x *RelayState) Reset() { *x = RelayState{} - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1659,7 +1909,7 @@ func (x *RelayState) String() string { func (*RelayState) ProtoMessage() {} func (x *RelayState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1672,7 +1922,7 @@ func (x *RelayState) ProtoReflect() protoreflect.Message { // Deprecated: Use RelayState.ProtoReflect.Descriptor instead. func (*RelayState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{17} + return file_daemon_proto_rawDescGZIP(), []int{19} } func (x *RelayState) GetURI() string { @@ -1708,7 +1958,7 @@ type NSGroupState struct { func (x *NSGroupState) Reset() { *x = NSGroupState{} - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1720,7 +1970,7 @@ func (x *NSGroupState) String() string { func (*NSGroupState) ProtoMessage() {} func (x *NSGroupState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1733,7 +1983,7 @@ func (x *NSGroupState) ProtoReflect() protoreflect.Message { // Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead. func (*NSGroupState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{18} + return file_daemon_proto_rawDescGZIP(), []int{20} } func (x *NSGroupState) GetServers() []string { @@ -1764,6 +2014,128 @@ func (x *NSGroupState) GetError() string { return "" } +// SSHSessionInfo contains information about an active SSH session +type SSHSessionInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"` + Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` + JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SSHSessionInfo) Reset() { + *x = SSHSessionInfo{} + mi := &file_daemon_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SSHSessionInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHSessionInfo) ProtoMessage() {} + +func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[21] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. +func (*SSHSessionInfo) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{21} +} + +func (x *SSHSessionInfo) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *SSHSessionInfo) GetRemoteAddress() string { + if x != nil { + return x.RemoteAddress + } + return "" +} + +func (x *SSHSessionInfo) GetCommand() string { + if x != nil { + return x.Command + } + return "" +} + +func (x *SSHSessionInfo) GetJwtUsername() string { + if x != nil { + return x.JwtUsername + } + return "" +} + +// SSHServerState contains the latest state of the SSH server +type SSHServerState struct { + state protoimpl.MessageState `protogen:"open.v1"` + Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` + Sessions []*SSHSessionInfo `protobuf:"bytes,2,rep,name=sessions,proto3" json:"sessions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SSHServerState) Reset() { + *x = SSHServerState{} + mi := &file_daemon_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SSHServerState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHServerState) ProtoMessage() {} + +func (x *SSHServerState) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[22] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. +func (*SSHServerState) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{22} +} + +func (x *SSHServerState) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +func (x *SSHServerState) GetSessions() []*SSHSessionInfo { + if x != nil { + return x.Sessions + } + return nil +} + // FullStatus contains the full state held by the Status instance type FullStatus struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1776,13 +2148,14 @@ type FullStatus struct { NumberOfForwardingRules int32 `protobuf:"varint,8,opt,name=NumberOfForwardingRules,proto3" json:"NumberOfForwardingRules,omitempty"` Events []*SystemEvent `protobuf:"bytes,7,rep,name=events,proto3" json:"events,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,9,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + SshServerState *SSHServerState `protobuf:"bytes,10,opt,name=sshServerState,proto3" json:"sshServerState,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1794,7 +2167,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1807,7 +2180,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -1873,6 +2246,13 @@ func (x *FullStatus) GetLazyConnectionEnabled() bool { return false } +func (x *FullStatus) GetSshServerState() *SSHServerState { + if x != nil { + return x.SshServerState + } + return nil +} + // Networks type ListNetworksRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1882,7 +2262,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1894,7 +2274,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1907,7 +2287,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{24} } type ListNetworksResponse struct { @@ -1919,7 +2299,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1931,7 +2311,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1944,7 +2324,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{25} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -1965,7 +2345,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1977,7 +2357,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1990,7 +2370,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2022,7 +2402,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2034,7 +2414,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2047,7 +2427,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{27} } type IPList struct { @@ -2059,7 +2439,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2071,7 +2451,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2084,7 +2464,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *IPList) GetIps() []string { @@ -2107,7 +2487,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2119,7 +2499,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2132,7 +2512,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *Network) GetID() string { @@ -2184,7 +2564,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2196,7 +2576,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2209,7 +2589,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2266,7 +2646,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2278,7 +2658,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2291,7 +2671,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *ForwardingRule) GetProtocol() string { @@ -2338,7 +2718,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2350,7 +2730,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2363,7 +2743,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2387,7 +2767,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2399,7 +2779,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2412,7 +2792,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{33} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2461,7 +2841,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2473,7 +2853,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2486,7 +2866,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *DebugBundleResponse) GetPath() string { @@ -2518,7 +2898,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2530,7 +2910,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2543,7 +2923,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{35} } type GetLogLevelResponse struct { @@ -2555,7 +2935,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2567,7 +2947,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2580,7 +2960,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{36} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -2599,7 +2979,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2611,7 +2991,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2624,7 +3004,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -2642,7 +3022,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2654,7 +3034,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2667,7 +3047,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{38} } // State represents a daemon state entry @@ -2680,7 +3060,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2692,7 +3072,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2705,7 +3085,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *State) GetName() string { @@ -2724,7 +3104,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2736,7 +3116,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2749,7 +3129,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{40} } // ListStatesResponse contains a list of states @@ -2762,7 +3142,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2774,7 +3154,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2787,7 +3167,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *ListStatesResponse) GetStates() []*State { @@ -2808,7 +3188,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2820,7 +3200,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2833,7 +3213,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *CleanStateRequest) GetStateName() string { @@ -2860,7 +3240,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2872,7 +3252,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2885,7 +3265,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -2906,7 +3286,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2918,7 +3298,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2931,7 +3311,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *DeleteStateRequest) GetStateName() string { @@ -2958,7 +3338,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2970,7 +3350,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2983,7 +3363,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{45} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3002,7 +3382,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3014,7 +3394,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3027,7 +3407,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3045,7 +3425,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3057,7 +3437,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3070,7 +3450,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{47} } type TCPFlags struct { @@ -3087,7 +3467,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3099,7 +3479,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3112,7 +3492,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TCPFlags) GetSyn() bool { @@ -3174,7 +3554,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3186,7 +3566,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3199,7 +3579,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3277,7 +3657,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3289,7 +3669,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3302,7 +3682,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{50} } func (x *TraceStage) GetName() string { @@ -3343,7 +3723,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3355,7 +3735,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3368,7 +3748,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3393,7 +3773,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3405,7 +3785,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3418,7 +3798,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{52} } type SystemEvent struct { @@ -3436,7 +3816,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3448,7 +3828,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3461,7 +3841,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *SystemEvent) GetId() string { @@ -3521,7 +3901,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3533,7 +3913,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3546,7 +3926,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{54} } type GetEventsResponse struct { @@ -3558,7 +3938,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3570,7 +3950,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3583,7 +3963,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{55} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -3603,7 +3983,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3615,7 +3995,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3628,7 +4008,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -3653,7 +4033,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3665,7 +4045,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3678,7 +4058,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{57} } type SetConfigRequest struct { @@ -3711,16 +4091,22 @@ type SetConfigRequest struct { ExtraIFaceBlacklist []string `protobuf:"bytes,24,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` DnsLabels []string `protobuf:"bytes,25,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"` // cleanDNSLabels clean map list of DNS labels. - CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` - DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` - Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` + DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` + EnableSSHRoot *bool `protobuf:"varint,29,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3732,7 +4118,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3745,7 +4131,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *SetConfigRequest) GetUsername() string { @@ -3944,6 +4330,48 @@ func (x *SetConfigRequest) GetMtu() int64 { return 0 } +func (x *SetConfigRequest) GetEnableSSHRoot() bool { + if x != nil && x.EnableSSHRoot != nil { + return *x.EnableSSHRoot + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHSFTP() bool { + if x != nil && x.EnableSSHSFTP != nil { + return *x.EnableSSHSFTP + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHLocalPortForwarding() bool { + if x != nil && x.EnableSSHLocalPortForwarding != nil { + return *x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHRemotePortForwarding() bool { + if x != nil && x.EnableSSHRemotePortForwarding != nil { + return *x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *SetConfigRequest) GetDisableSSHAuth() bool { + if x != nil && x.DisableSSHAuth != nil { + return *x.DisableSSHAuth + } + return false +} + +func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 { + if x != nil && x.SshJWTCacheTTL != nil { + return *x.SshJWTCacheTTL + } + return 0 +} + type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -3952,7 +4380,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3964,7 +4392,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3977,7 +4405,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{59} } type AddProfileRequest struct { @@ -3990,7 +4418,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4002,7 +4430,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4015,7 +4443,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *AddProfileRequest) GetUsername() string { @@ -4040,7 +4468,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4052,7 +4480,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4065,7 +4493,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{61} } type RemoveProfileRequest struct { @@ -4078,7 +4506,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4090,7 +4518,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4103,7 +4531,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4128,7 +4556,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4140,7 +4568,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4153,7 +4581,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{63} } type ListProfilesRequest struct { @@ -4165,7 +4593,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4177,7 +4605,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4190,7 +4618,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *ListProfilesRequest) GetUsername() string { @@ -4209,7 +4637,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4221,7 +4649,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4234,7 +4662,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{65} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4254,7 +4682,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4266,7 +4694,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4279,7 +4707,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *Profile) GetName() string { @@ -4304,7 +4732,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4316,7 +4744,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4329,7 +4757,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{67} } type GetActiveProfileResponse struct { @@ -4342,7 +4770,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4354,7 +4782,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4367,7 +4795,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{68} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4394,7 +4822,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4406,7 +4834,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4419,7 +4847,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{69} } func (x *LogoutRequest) GetProfileName() string { @@ -4444,7 +4872,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4456,7 +4884,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4469,7 +4897,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{70} } type GetFeaturesRequest struct { @@ -4480,7 +4908,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[71] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4492,7 +4920,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[71] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4505,7 +4933,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{71} } type GetFeaturesResponse struct { @@ -4518,7 +4946,7 @@ type GetFeaturesResponse struct { func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[72] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4530,7 +4958,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[72] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4543,7 +4971,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{72} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -4560,6 +4988,478 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer +type GetPeerSSHHostKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // peer IP address or FQDN to get SSH host key for + PeerAddress string `protobuf:"bytes,1,opt,name=peerAddress,proto3" json:"peerAddress,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPeerSSHHostKeyRequest) Reset() { + *x = GetPeerSSHHostKeyRequest{} + mi := &file_daemon_proto_msgTypes[73] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPeerSSHHostKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} + +func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[73] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. +func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{73} +} + +func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { + if x != nil { + return x.PeerAddress + } + return "" +} + +// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer +type GetPeerSSHHostKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname") + SshHostKey []byte `protobuf:"bytes,1,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` + // peer IP address + PeerIP string `protobuf:"bytes,2,opt,name=peerIP,proto3" json:"peerIP,omitempty"` + // peer FQDN + PeerFQDN string `protobuf:"bytes,3,opt,name=peerFQDN,proto3" json:"peerFQDN,omitempty"` + // indicates if the SSH host key was found + Found bool `protobuf:"varint,4,opt,name=found,proto3" json:"found,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPeerSSHHostKeyResponse) Reset() { + *x = GetPeerSSHHostKeyResponse{} + mi := &file_daemon_proto_msgTypes[74] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPeerSSHHostKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} + +func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[74] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. +func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{74} +} + +func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { + if x != nil { + return x.SshHostKey + } + return nil +} + +func (x *GetPeerSSHHostKeyResponse) GetPeerIP() string { + if x != nil { + return x.PeerIP + } + return "" +} + +func (x *GetPeerSSHHostKeyResponse) GetPeerFQDN() string { + if x != nil { + return x.PeerFQDN + } + return "" +} + +func (x *GetPeerSSHHostKeyResponse) GetFound() bool { + if x != nil { + return x.Found + } + return false +} + +// RequestJWTAuthRequest for initiating JWT authentication flow +type RequestJWTAuthRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // hint for OIDC login_hint parameter (typically email address) + Hint *string `protobuf:"bytes,1,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestJWTAuthRequest) Reset() { + *x = RequestJWTAuthRequest{} + mi := &file_daemon_proto_msgTypes[75] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestJWTAuthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestJWTAuthRequest) ProtoMessage() {} + +func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[75] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. +func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{75} +} + +func (x *RequestJWTAuthRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + +// RequestJWTAuthResponse contains authentication flow information +type RequestJWTAuthResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // verification URI for user authentication + VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"` + // complete verification URI (with embedded user code) + VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"` + // user code to enter on verification URI + UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"` + // device code for polling + DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"` + // expiration time in seconds + ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"` + // if a cached token is available, it will be returned here + CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"` + // maximum age of JWT tokens in seconds (from management server) + MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestJWTAuthResponse) Reset() { + *x = RequestJWTAuthResponse{} + mi := &file_daemon_proto_msgTypes[76] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestJWTAuthResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestJWTAuthResponse) ProtoMessage() {} + +func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[76] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. +func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{76} +} + +func (x *RequestJWTAuthResponse) GetVerificationURI() string { + if x != nil { + return x.VerificationURI + } + return "" +} + +func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string { + if x != nil { + return x.VerificationURIComplete + } + return "" +} + +func (x *RequestJWTAuthResponse) GetUserCode() string { + if x != nil { + return x.UserCode + } + return "" +} + +func (x *RequestJWTAuthResponse) GetDeviceCode() string { + if x != nil { + return x.DeviceCode + } + return "" +} + +func (x *RequestJWTAuthResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + +func (x *RequestJWTAuthResponse) GetCachedToken() string { + if x != nil { + return x.CachedToken + } + return "" +} + +func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 { + if x != nil { + return x.MaxTokenAge + } + return 0 +} + +// WaitJWTTokenRequest for waiting for authentication completion +type WaitJWTTokenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // device code from RequestJWTAuthResponse + DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"` + // user code for verification + UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WaitJWTTokenRequest) Reset() { + *x = WaitJWTTokenRequest{} + mi := &file_daemon_proto_msgTypes[77] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WaitJWTTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WaitJWTTokenRequest) ProtoMessage() {} + +func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[77] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. +func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{77} +} + +func (x *WaitJWTTokenRequest) GetDeviceCode() string { + if x != nil { + return x.DeviceCode + } + return "" +} + +func (x *WaitJWTTokenRequest) GetUserCode() string { + if x != nil { + return x.UserCode + } + return "" +} + +// WaitJWTTokenResponse contains the JWT token after authentication +type WaitJWTTokenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // JWT token (access token or ID token) + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + // token type (e.g., "Bearer") + TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"` + // expiration time in seconds + ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WaitJWTTokenResponse) Reset() { + *x = WaitJWTTokenResponse{} + mi := &file_daemon_proto_msgTypes[78] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WaitJWTTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WaitJWTTokenResponse) ProtoMessage() {} + +func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[78] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. +func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{78} +} + +func (x *WaitJWTTokenResponse) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *WaitJWTTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + +type InstallerResultRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstallerResultRequest) Reset() { + *x = InstallerResultRequest{} + mi := &file_daemon_proto_msgTypes[79] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstallerResultRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstallerResultRequest) ProtoMessage() {} + +func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[79] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. +func (*InstallerResultRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{79} +} + +type InstallerResultResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstallerResultResponse) Reset() { + *x = InstallerResultResponse{} + mi := &file_daemon_proto_msgTypes[80] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstallerResultResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstallerResultResponse) ProtoMessage() {} + +func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[80] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. +func (*InstallerResultResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{80} +} + +func (x *InstallerResultResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *InstallerResultResponse) GetErrorMsg() string { + if x != nil { + return x.ErrorMsg + } + return "" +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -4570,7 +5470,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4582,7 +5482,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4595,7 +5495,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26, 0} + return file_daemon_proto_rawDescGZIP(), []int{30, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -4617,7 +5517,15 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xc3\x0e\n" + + "\fEmptyRequest\"\x7f\n" + + "\x12OSLifecycleRequest\x128\n" + + "\x04type\x18\x01 \x01(\x0e2$.daemon.OSLifecycleRequest.CycleTypeR\x04type\"/\n" + + "\tCycleType\x12\v\n" + + "\aUNKNOWN\x10\x00\x12\t\n" + + "\x05SLEEP\x10\x01\x12\n" + + "\n" + + "\x06WAKEUP\x10\x02\"\x15\n" + + "\x13OSLifecycleResponse\"\xb6\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4654,7 +5562,14 @@ const file_daemon_proto_rawDesc = "" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01\x12)\n" + + "\renableSSHRoot\x18\" \x01(\bH\x15R\renableSSHRoot\x88\x01\x01\x12)\n" + + "\renableSSHSFTP\x18# \x01(\bH\x16R\renableSSHSFTP\x88\x01\x01\x12G\n" + + "\x1cenableSSHLocalPortForwarding\x18$ \x01(\bH\x17R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + + "\x1denableSSHRemotePortForwarding\x18% \x01(\bH\x18R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + + "\x0edisableSSHAuth\x18& \x01(\bH\x19R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + + "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4674,7 +5589,14 @@ const file_daemon_proto_rawDesc = "" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + - "\x04_mtu\"\xb5\x01\n" + + "\x04_mtuB\a\n" + + "\x05_hintB\x10\n" + + "\x0e_enableSSHRootB\x10\n" + + "\x0e_enableSSHSFTPB\x1f\n" + + "\x1d_enableSSHLocalPortForwardingB \n" + + "\x1e_enableSSHRemotePortForwardingB\x11\n" + + "\x0f_disableSSHAuthB\x11\n" + + "\x0f_sshJWTCacheTTL\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -4684,12 +5606,16 @@ const file_daemon_proto_rawDesc = "" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + "\x14WaitSSOLoginResponse\x12\x14\n" + - "\x05email\x18\x01 \x01(\tR\x05email\"p\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" + "\tUpRequest\x12%\n" + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" + + "\n" + + "autoUpdate\x18\x03 \x01(\bH\x02R\n" + + "autoUpdate\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + - "\t_username\"\f\n" + + "\t_usernameB\r\n" + + "\v_autoUpdate\"\f\n" + "\n" + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + @@ -4707,7 +5633,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xe0\x06\n" + + "\busername\x18\x02 \x01(\tR\busername\"\x86\t\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -4732,8 +5658,14 @@ const file_daemon_proto_rawDesc = "" + "disableDns\x122\n" + "\x15disable_client_routes\x18\x12 \x01(\bR\x13disableClientRoutes\x122\n" + "\x15disable_server_routes\x18\x13 \x01(\bR\x13disableServerRoutes\x12(\n" + - "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12)\n" + - "\x10disable_firewall\x18\x15 \x01(\bR\x0fdisableFirewall\"\xde\x05\n" + + "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12$\n" + + "\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" + + "\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" + + "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" + + "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" + + "\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" + + "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12)\n" + + "\x10disable_firewall\x18\x1b \x01(\bR\x0fdisableFirewall\"\xfe\x05\n" + "\tPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" + @@ -4754,7 +5686,10 @@ const file_daemon_proto_rawDesc = "" + "\x10rosenpassEnabled\x18\x0f \x01(\bR\x10rosenpassEnabled\x12\x1a\n" + "\bnetworks\x18\x10 \x03(\tR\bnetworks\x123\n" + "\alatency\x18\x11 \x01(\v2\x19.google.protobuf.DurationR\alatency\x12\"\n" + - "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\"\xf0\x01\n" + + "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" + + "\n" + + "sshHostKey\x18\x13 \x01(\fR\n" + + "sshHostKey\"\xf0\x01\n" + "\x0eLocalPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" + @@ -4780,7 +5715,15 @@ const file_daemon_proto_rawDesc = "" + "\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" + "\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" + "\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" + - "\x05error\x18\x04 \x01(\tR\x05error\"\xef\x03\n" + + "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" + + "\x0eSSHSessionInfo\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\x12$\n" + + "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" + + "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" + + "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" + + "\x0eSSHServerState\x12\x18\n" + + "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" + + "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" + "\n" + "FullStatus\x12A\n" + "\x0fmanagementState\x18\x01 \x01(\v2\x17.daemon.ManagementStateR\x0fmanagementState\x125\n" + @@ -4792,7 +5735,9 @@ const file_daemon_proto_rawDesc = "" + "dnsServers\x128\n" + "\x17NumberOfForwardingRules\x18\b \x01(\x05R\x17NumberOfForwardingRules\x12+\n" + "\x06events\x18\a \x03(\v2\x13.daemon.SystemEventR\x06events\x124\n" + - "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\"\x15\n" + + "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\x12>\n" + + "\x0esshServerState\x18\n" + + " \x01(\v2\x16.daemon.SSHServerStateR\x0esshServerState\"\x15\n" + "\x13ListNetworksRequest\"?\n" + "\x14ListNetworksResponse\x12'\n" + "\x06routes\x18\x01 \x03(\v2\x0f.daemon.NetworkR\x06routes\"a\n" + @@ -4933,7 +5878,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\x8e\r\n" + + "\x15SwitchProfileResponse\"\xdf\x10\n" + "\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -4966,7 +5911,13 @@ const file_daemon_proto_rawDesc = "" + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01\x12)\n" + + "\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" + + "\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12G\n" + + "\x1cenableSSHLocalPortForwarding\x18\x1f \x01(\bH\x14R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + + "\x1denableSSHRemotePortForwarding\x18 \x01(\bH\x15R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + + "\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + + "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4984,7 +5935,13 @@ const file_daemon_proto_rawDesc = "" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x13\n" + "\x11_dnsRouteIntervalB\x06\n" + - "\x04_mtu\"\x13\n" + + "\x04_mtuB\x10\n" + + "\x0e_enableSSHRootB\x10\n" + + "\x0e_enableSSHSFTPB\x1f\n" + + "\x1d_enableSSHLocalPortForwardingB \n" + + "\x1e_enableSSHRemotePortForwardingB\x11\n" + + "\x0f_disableSSHAuthB\x11\n" + + "\x0f_sshJWTCacheTTL\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + @@ -5014,7 +5971,42 @@ const file_daemon_proto_rawDesc = "" + "\x12GetFeaturesRequest\"x\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" + + "\x18GetPeerSSHHostKeyRequest\x12 \n" + + "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" + + "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" + + "\n" + + "sshHostKey\x18\x01 \x01(\fR\n" + + "sshHostKey\x12\x16\n" + + "\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" + + "\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" + + "\x05found\x18\x04 \x01(\bR\x05found\"9\n" + + "\x15RequestJWTAuthRequest\x12\x17\n" + + "\x04hint\x18\x01 \x01(\tH\x00R\x04hint\x88\x01\x01B\a\n" + + "\x05_hint\"\x9a\x02\n" + + "\x16RequestJWTAuthResponse\x12(\n" + + "\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" + + "\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" + + "\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" + + "\n" + + "deviceCode\x18\x04 \x01(\tR\n" + + "deviceCode\x12\x1c\n" + + "\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" + + "\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" + + "\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" + + "\x13WaitJWTTokenRequest\x12\x1e\n" + + "\n" + + "deviceCode\x18\x01 \x01(\tR\n" + + "deviceCode\x12\x1a\n" + + "\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" + + "\x14WaitJWTTokenResponse\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" + + "\x16InstallerResultRequest\"O\n" + + "\x17InstallerResultResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -5023,7 +6015,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\x8f\x10\n" + + "\x05TRACE\x10\a2\xb4\x13\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5055,7 +6047,12 @@ const file_daemon_proto_rawDesc = "" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + - "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5069,180 +6066,206 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72) +var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel - (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity - (SystemEvent_Category)(0), // 2: daemon.SystemEvent.Category - (*EmptyRequest)(nil), // 3: daemon.EmptyRequest - (*LoginRequest)(nil), // 4: daemon.LoginRequest - (*LoginResponse)(nil), // 5: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 6: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 7: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 8: daemon.UpRequest - (*UpResponse)(nil), // 9: daemon.UpResponse - (*StatusRequest)(nil), // 10: daemon.StatusRequest - (*StatusResponse)(nil), // 11: daemon.StatusResponse - (*DownRequest)(nil), // 12: daemon.DownRequest - (*DownResponse)(nil), // 13: daemon.DownResponse - (*GetConfigRequest)(nil), // 14: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 15: daemon.GetConfigResponse - (*PeerState)(nil), // 16: daemon.PeerState - (*LocalPeerState)(nil), // 17: daemon.LocalPeerState - (*SignalState)(nil), // 18: daemon.SignalState - (*ManagementState)(nil), // 19: daemon.ManagementState - (*RelayState)(nil), // 20: daemon.RelayState - (*NSGroupState)(nil), // 21: daemon.NSGroupState - (*FullStatus)(nil), // 22: daemon.FullStatus - (*ListNetworksRequest)(nil), // 23: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 24: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 25: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 26: daemon.SelectNetworksResponse - (*IPList)(nil), // 27: daemon.IPList - (*Network)(nil), // 28: daemon.Network - (*PortInfo)(nil), // 29: daemon.PortInfo - (*ForwardingRule)(nil), // 30: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 31: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 32: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 33: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 34: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 35: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 36: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 37: daemon.SetLogLevelResponse - (*State)(nil), // 38: daemon.State - (*ListStatesRequest)(nil), // 39: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 40: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 41: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 42: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 43: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 44: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 45: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 46: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 47: daemon.TCPFlags - (*TracePacketRequest)(nil), // 48: daemon.TracePacketRequest - (*TraceStage)(nil), // 49: daemon.TraceStage - (*TracePacketResponse)(nil), // 50: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 51: daemon.SubscribeRequest - (*SystemEvent)(nil), // 52: daemon.SystemEvent - (*GetEventsRequest)(nil), // 53: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 54: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 55: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 56: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 57: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 58: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 59: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 60: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 61: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 62: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 63: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 64: daemon.ListProfilesResponse - (*Profile)(nil), // 65: daemon.Profile - (*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 68: daemon.LogoutRequest - (*LogoutResponse)(nil), // 69: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse - nil, // 72: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range - nil, // 74: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 75: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp + (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType + (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity + (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category + (*EmptyRequest)(nil), // 4: daemon.EmptyRequest + (*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest + (*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse + (*LoginRequest)(nil), // 7: daemon.LoginRequest + (*LoginResponse)(nil), // 8: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 11: daemon.UpRequest + (*UpResponse)(nil), // 12: daemon.UpResponse + (*StatusRequest)(nil), // 13: daemon.StatusRequest + (*StatusResponse)(nil), // 14: daemon.StatusResponse + (*DownRequest)(nil), // 15: daemon.DownRequest + (*DownResponse)(nil), // 16: daemon.DownResponse + (*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse + (*PeerState)(nil), // 19: daemon.PeerState + (*LocalPeerState)(nil), // 20: daemon.LocalPeerState + (*SignalState)(nil), // 21: daemon.SignalState + (*ManagementState)(nil), // 22: daemon.ManagementState + (*RelayState)(nil), // 23: daemon.RelayState + (*NSGroupState)(nil), // 24: daemon.NSGroupState + (*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 26: daemon.SSHServerState + (*FullStatus)(nil), // 27: daemon.FullStatus + (*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse + (*IPList)(nil), // 32: daemon.IPList + (*Network)(nil), // 33: daemon.Network + (*PortInfo)(nil), // 34: daemon.PortInfo + (*ForwardingRule)(nil), // 35: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse + (*State)(nil), // 43: daemon.State + (*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 52: daemon.TCPFlags + (*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest + (*TraceStage)(nil), // 54: daemon.TraceStage + (*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest + (*SystemEvent)(nil), // 57: daemon.SystemEvent + (*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse + (*Profile)(nil), // 70: daemon.Profile + (*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 73: daemon.LogoutRequest + (*LogoutResponse)(nil), // 74: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse + (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse + (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse + nil, // 85: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range + nil, // 87: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 88: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 16, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState - 20, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState - 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 18: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 19: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 38, // 20: daemon.ListStatesResponse.states:type_name -> daemon.State - 47, // 21: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 6, // 32: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 8, // 33: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 10, // 34: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 12, // 35: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 14, // 36: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 23, // 37: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 25, // 38: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 25, // 39: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 3, // 40: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 32, // 41: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 34, // 42: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 36, // 43: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 39, // 44: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 41, // 45: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 43, // 46: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 45, // 47: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 48, // 48: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 51, // 49: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 53, // 50: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 55, // 51: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 57, // 52: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 59, // 53: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 59, // [59:87] is the sub-list for method output_type - 31, // [31:59] is the sub-list for method input_type - 31, // [31:31] is the sub-list for extension type_name - 31, // [31:31] is the sub-list for extension extendee - 0, // [0:31] is the sub-list for field type_name + 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType + 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState + 23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState + 24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State + 52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 67, // [67:100] is the sub-list for method output_type + 34, // [34:67] is the sub-list for method input_type + 34, // [34:34] is the sub-list for extension type_name + 34, // [34:34] is the sub-list for extension extendee + 0, // [0:34] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -5250,25 +6273,26 @@ func file_daemon_proto_init() { if File_daemon_proto != nil { return } - file_daemon_proto_msgTypes[1].OneofWrappers = []any{} - file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[3].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[26].OneofWrappers = []any{ + file_daemon_proto_msgTypes[9].OneofWrappers = []any{} + file_daemon_proto_msgTypes[30].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[45].OneofWrappers = []any{} - file_daemon_proto_msgTypes[46].OneofWrappers = []any{} - file_daemon_proto_msgTypes[52].OneofWrappers = []any{} - file_daemon_proto_msgTypes[54].OneofWrappers = []any{} - file_daemon_proto_msgTypes[65].OneofWrappers = []any{} + file_daemon_proto_msgTypes[49].OneofWrappers = []any{} + file_daemon_proto_msgTypes[50].OneofWrappers = []any{} + file_daemon_proto_msgTypes[56].OneofWrappers = []any{} + file_daemon_proto_msgTypes[58].OneofWrappers = []any{} + file_daemon_proto_msgTypes[69].OneofWrappers = []any{} + file_daemon_proto_msgTypes[75].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), - NumEnums: 3, - NumMessages: 72, + NumEnums: 4, + NumMessages: 84, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 3bf86873c..fb34e959d 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -24,7 +24,7 @@ service DaemonService { // Status of the service. rpc Status(StatusRequest) returns (StatusResponse) {} - // Down engine work in the daemon. + // Down stops engine work in the daemon. rpc Down(DownRequest) returns (DownResponse) {} // GetConfig of the daemon. @@ -84,9 +84,37 @@ service DaemonService { rpc Logout(LogoutRequest) returns (LogoutResponse) {} rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {} + + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} + + // RequestJWTAuth initiates JWT authentication flow for SSH + rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {} + + // WaitJWTToken waits for JWT authentication completion + rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} + + rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} + + rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} } + +message OSLifecycleRequest { + // avoid collision with loglevel enum + enum CycleType { + UNKNOWN = 0; + SLEEP = 1; + WAKEUP = 2; + } + + CycleType type = 1; +} + +message OSLifecycleResponse {} + + message LoginRequest { // setupKey netbird setup key. string setupKey = 1; @@ -158,6 +186,16 @@ message LoginRequest { optional string username = 31; optional int64 mtu = 32; + + // hint is used to pre-fill the email/username field during SSO authentication + optional string hint = 33; + + optional bool enableSSHRoot = 34; + optional bool enableSSHSFTP = 35; + optional bool enableSSHLocalPortForwarding = 36; + optional bool enableSSHRemotePortForwarding = 37; + optional bool disableSSHAuth = 38; + optional int32 sshJWTCacheTTL = 39; } message LoginResponse { @@ -179,15 +217,16 @@ message WaitSSOLoginResponse { message UpRequest { optional string profileName = 1; optional string username = 2; + optional bool autoUpdate = 3; } message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; - bool shouldRunProbes = 2; + bool shouldRunProbes = 2; // the UI do not using this yet, but CLIs could use it to wait until the status is ready - optional bool waitForReady = 3; + optional bool waitForReady = 3; } message StatusResponse{ @@ -253,7 +292,19 @@ message GetConfigResponse { bool block_lan_access = 20; - bool disable_firewall = 21; + bool enableSSHRoot = 21; + + bool enableSSHSFTP = 24; + + bool enableSSHLocalPortForwarding = 22; + + bool enableSSHRemotePortForwarding = 23; + + bool disableSSHAuth = 25; + + int32 sshJWTCacheTTL = 26; + + bool disable_firewall = 27; } // PeerState contains the latest state of a peer @@ -275,6 +326,7 @@ message PeerState { repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; + bytes sshHostKey = 19; } // LocalPeerState contains the latest state of the local peer @@ -316,6 +368,20 @@ message NSGroupState { string error = 4; } +// SSHSessionInfo contains information about an active SSH session +message SSHSessionInfo { + string username = 1; + string remoteAddress = 2; + string command = 3; + string jwtUsername = 4; +} + +// SSHServerState contains the latest state of the SSH server +message SSHServerState { + bool enabled = 1; + repeated SSHSessionInfo sessions = 2; +} + // FullStatus contains the full state held by the Status instance message FullStatus { ManagementState managementState = 1; @@ -329,6 +395,7 @@ message FullStatus { repeated SystemEvent events = 7; bool lazyConnectionEnabled = 9; + SSHServerState sshServerState = 10; } // Networks @@ -542,56 +609,63 @@ message SwitchProfileRequest { message SwitchProfileResponse {} message SetConfigRequest { - string username = 1; - string profileName = 2; - // managementUrl to authenticate. - string managementUrl = 3; + string username = 1; + string profileName = 2; + // managementUrl to authenticate. + string managementUrl = 3; - // adminUrl to manage keys. - string adminURL = 4; + // adminUrl to manage keys. + string adminURL = 4; - optional bool rosenpassEnabled = 5; + optional bool rosenpassEnabled = 5; - optional string interfaceName = 6; + optional string interfaceName = 6; - optional int64 wireguardPort = 7; + optional int64 wireguardPort = 7; - optional string optionalPreSharedKey = 8; + optional string optionalPreSharedKey = 8; - optional bool disableAutoConnect = 9; + optional bool disableAutoConnect = 9; - optional bool serverSSHAllowed = 10; + optional bool serverSSHAllowed = 10; - optional bool rosenpassPermissive = 11; + optional bool rosenpassPermissive = 11; - optional bool networkMonitor = 12; + optional bool networkMonitor = 12; - optional bool disable_client_routes = 13; - optional bool disable_server_routes = 14; - optional bool disable_dns = 15; - optional bool disable_firewall = 16; - optional bool block_lan_access = 17; + optional bool disable_client_routes = 13; + optional bool disable_server_routes = 14; + optional bool disable_dns = 15; + optional bool disable_firewall = 16; + optional bool block_lan_access = 17; - optional bool disable_notifications = 18; + optional bool disable_notifications = 18; - optional bool lazyConnectionEnabled = 19; + optional bool lazyConnectionEnabled = 19; - optional bool block_inbound = 20; + optional bool block_inbound = 20; - repeated string natExternalIPs = 21; - bool cleanNATExternalIPs = 22; + repeated string natExternalIPs = 21; + bool cleanNATExternalIPs = 22; - bytes customDNSAddress = 23; + bytes customDNSAddress = 23; - repeated string extraIFaceBlacklist = 24; + repeated string extraIFaceBlacklist = 24; - repeated string dns_labels = 25; - // cleanDNSLabels clean map list of DNS labels. - bool cleanDNSLabels = 26; + repeated string dns_labels = 25; + // cleanDNSLabels clean map list of DNS labels. + bool cleanDNSLabels = 26; - optional google.protobuf.Duration dnsRouteInterval = 27; + optional google.protobuf.Duration dnsRouteInterval = 27; - optional int64 mtu = 28; + optional int64 mtu = 28; + + optional bool enableSSHRoot = 29; + optional bool enableSSHSFTP = 30; + optional bool enableSSHLocalPortForwarding = 31; + optional bool enableSSHRemotePortForwarding = 32; + optional bool disableSSHAuth = 33; + optional int32 sshJWTCacheTTL = 34; } message SetConfigResponse{} @@ -643,3 +717,71 @@ message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; } + +// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer +message GetPeerSSHHostKeyRequest { + // peer IP address or FQDN to get SSH host key for + string peerAddress = 1; +} + +// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer +message GetPeerSSHHostKeyResponse { + // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname") + bytes sshHostKey = 1; + // peer IP address + string peerIP = 2; + // peer FQDN + string peerFQDN = 3; + // indicates if the SSH host key was found + bool found = 4; +} + +// RequestJWTAuthRequest for initiating JWT authentication flow +message RequestJWTAuthRequest { + // hint for OIDC login_hint parameter (typically email address) + optional string hint = 1; +} + +// RequestJWTAuthResponse contains authentication flow information +message RequestJWTAuthResponse { + // verification URI for user authentication + string verificationURI = 1; + // complete verification URI (with embedded user code) + string verificationURIComplete = 2; + // user code to enter on verification URI + string userCode = 3; + // device code for polling + string deviceCode = 4; + // expiration time in seconds + int64 expiresIn = 5; + // if a cached token is available, it will be returned here + string cachedToken = 6; + // maximum age of JWT tokens in seconds (from management server) + int64 maxTokenAge = 7; +} + +// WaitJWTTokenRequest for waiting for authentication completion +message WaitJWTTokenRequest { + // device code from RequestJWTAuthResponse + string deviceCode = 1; + // user code for verification + string userCode = 2; +} + +// WaitJWTTokenResponse contains the JWT token after authentication +message WaitJWTTokenResponse { + // JWT token (access token or ID token) + string token = 1; + // token type (e.g., "Bearer") + string tokenType = 2; + // expiration time in seconds + int64 expiresIn = 3; +} + +message InstallerResultRequest { +} + +message InstallerResultResponse { + bool success = 1; + string errorMsg = 2; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index bf7c9c7b3..fdabb1879 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -27,7 +27,7 @@ type DaemonServiceClient interface { Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) // Status of the service. Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) - // Down engine work in the daemon. + // Down stops engine work in the daemon. Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) @@ -64,6 +64,14 @@ type DaemonServiceClient interface { // Logout disconnects from the network and deletes the peer from the management server Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) + // RequestJWTAuth initiates JWT authentication flow for SSH + RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) + // WaitJWTToken waits for JWT authentication completion + WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) + NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) + GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) } type daemonServiceClient struct { @@ -349,6 +357,51 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe return out, nil } +func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { + out := new(GetPeerSSHHostKeyResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) { + out := new(RequestJWTAuthResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) { + out := new(WaitJWTTokenResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { + out := new(OSLifecycleResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) { + out := new(InstallerResultResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -362,7 +415,7 @@ type DaemonServiceServer interface { Up(context.Context, *UpRequest) (*UpResponse, error) // Status of the service. Status(context.Context, *StatusRequest) (*StatusResponse, error) - // Down engine work in the daemon. + // Down stops engine work in the daemon. Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) @@ -399,6 +452,14 @@ type DaemonServiceServer interface { // Logout disconnects from the network and deletes the peer from the management server Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) + // RequestJWTAuth initiates JWT authentication flow for SSH + RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) + // WaitJWTToken waits for JWT authentication completion + WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) + NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) + GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -490,6 +551,21 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") } +func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") +} +func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented") +} +func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") +} +func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") +} +func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1010,6 +1086,96 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetPeerSSHHostKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RequestJWTAuthRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/RequestJWTAuth", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(WaitJWTTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).WaitJWTToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/WaitJWTToken", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(OSLifecycleRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/NotifyOSLifecycle", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InstallerResultRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetInstallerResult(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetInstallerResult", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1125,6 +1291,26 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeatures", Handler: _DaemonService_GetFeatures_Handler, }, + { + MethodName: "GetPeerSSHHostKey", + Handler: _DaemonService_GetPeerSSHHostKey_Handler, + }, + { + MethodName: "RequestJWTAuth", + Handler: _DaemonService_RequestJWTAuth_Handler, + }, + { + MethodName: "WaitJWTToken", + Handler: _DaemonService_WaitJWTToken_Handler, + }, + { + MethodName: "NotifyOSLifecycle", + Handler: _DaemonService_NotifyOSLifecycle_Handler, + }, + { + MethodName: "GetInstallerResult", + Handler: _DaemonService_GetInstallerResult_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/proto/generate.sh b/client/proto/generate.sh index f9a2c3750..e659cef90 100755 --- a/client/proto/generate.sh +++ b/client/proto/generate.sh @@ -14,4 +14,4 @@ cd "$script_path" go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional -cd "$old_pwd" \ No newline at end of file +cd "$old_pwd" diff --git a/client/server/jwt_cache.go b/client/server/jwt_cache.go new file mode 100644 index 000000000..21e170517 --- /dev/null +++ b/client/server/jwt_cache.go @@ -0,0 +1,79 @@ +package server + +import ( + "sync" + "time" + + "github.com/awnumar/memguard" + log "github.com/sirupsen/logrus" +) + +type jwtCache struct { + mu sync.RWMutex + enclave *memguard.Enclave + expiresAt time.Time + timer *time.Timer + maxTokenSize int +} + +func newJWTCache() *jwtCache { + return &jwtCache{ + maxTokenSize: 8192, + } +} + +func (c *jwtCache) store(token string, maxAge time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cleanup() + + if c.timer != nil { + c.timer.Stop() + } + + tokenBytes := []byte(token) + c.enclave = memguard.NewEnclave(tokenBytes) + + c.expiresAt = time.Now().Add(maxAge) + + var timer *time.Timer + timer = time.AfterFunc(maxAge, func() { + c.mu.Lock() + defer c.mu.Unlock() + if c.timer != timer { + return + } + c.cleanup() + c.timer = nil + log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge) + }) + c.timer = timer +} + +func (c *jwtCache) get() (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.enclave == nil || time.Now().After(c.expiresAt) { + return "", false + } + + buffer, err := c.enclave.Open() + if err != nil { + log.Debugf("Failed to open JWT token enclave: %v", err) + return "", false + } + defer buffer.Destroy() + + token := string(buffer.Bytes()) + return token, true +} + +// cleanup destroys the secure enclave, must be called with lock held +func (c *jwtCache) cleanup() { + if c.enclave != nil { + c.enclave = nil + } + c.expiresAt = time.Time{} +} diff --git a/client/server/lifecycle.go b/client/server/lifecycle.go new file mode 100644 index 000000000..3722c027d --- /dev/null +++ b/client/server/lifecycle.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" +) + +// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. +func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { + switch req.GetType() { + case proto.OSLifecycleRequest_WAKEUP: + return s.handleWakeUp(callerCtx) + case proto.OSLifecycleRequest_SLEEP: + return s.handleSleep(callerCtx) + default: + log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) + } + return &proto.OSLifecycleResponse{}, nil +} + +// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep. +// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails. +func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { + if !s.sleepTriggeredDown.Load() { + log.Info("skipping up because wasn't sleep down") + return &proto.OSLifecycleResponse{}, nil + } + + // avoid other wakeup runs if sleep didn't make the computer sleep + s.sleepTriggeredDown.Store(false) + + log.Info("running up after wake up") + _, err := s.Up(callerCtx, &proto.UpRequest{}) + if err != nil { + log.Errorf("running up failed: %v", err) + return &proto.OSLifecycleResponse{}, err + } + + log.Info("running up command executed successfully") + return &proto.OSLifecycleResponse{}, nil +} + +// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state. +func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { + s.mutex.Lock() + + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return &proto.OSLifecycleResponse{}, err + } + + if status != internal.StatusConnecting && status != internal.StatusConnected { + log.Infof("skipping setting the agent down because status is %s", status) + s.mutex.Unlock() + return &proto.OSLifecycleResponse{}, nil + } + s.mutex.Unlock() + + log.Info("running down after system started sleeping") + + _, err = s.Down(callerCtx, &proto.DownRequest{}) + if err != nil { + log.Errorf("running down failed: %v", err) + return &proto.OSLifecycleResponse{}, err + } + + s.sleepTriggeredDown.Store(true) + + log.Info("running down executed successfully") + return &proto.OSLifecycleResponse{}, nil +} diff --git a/client/server/lifecycle_test.go b/client/server/lifecycle_test.go new file mode 100644 index 000000000..a604c60af --- /dev/null +++ b/client/server/lifecycle_test.go @@ -0,0 +1,219 @@ +package server + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/proto" +) + +func newTestServer() *Server { + ctx := internal.CtxInitState(context.Background()) + return &Server{ + rootCtx: ctx, + statusRecorder: peer.NewRecorder(""), + } +} + +func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) { + s := newTestServer() + + // sleepTriggeredDown is false by default + assert.False(t, s.sleepTriggeredDown.Load()) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false") +} + +func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusIdle) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle") +} + +func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusNeedsLogin) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin") +} + +func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusConnecting) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + assert.NotNil(t, resp, "handleSleep returns not nil response on success") + assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting") +} + +func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusConnected) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + assert.NotNil(t, resp, "handleSleep returns not nil response on success") + assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected") +} + +func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) { + s := newTestServer() + + // Manually set the flag to simulate prior sleep down + s.sleepTriggeredDown.Store(true) + + // WakeUp will try to call Up which fails without proper setup, but flag should reset first + _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + + assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt") +} + +func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) { + s := newTestServer() + + // First wakeup without prior sleep - should be no-op + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) + + // Simulate prior sleep + s.sleepTriggeredDown.Store(true) + + // First wakeup after sleep - should reset flag + _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + assert.False(t, s.sleepTriggeredDown.Load()) + + // Second wakeup - should be no-op + resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) +} + +func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) { + s := newTestServer() + + resp, err := s.handleWakeUp(context.Background()) + + require.NoError(t, err) + require.NotNil(t, resp) +} + +func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) { + s := newTestServer() + s.sleepTriggeredDown.Store(true) + + // Even if Up fails, flag should be reset + _, _ = s.handleWakeUp(context.Background()) + + assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up") +} + +func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Idle", internal.StatusIdle}, + {"NeedsLogin", internal.StatusNeedsLogin}, + {"LoginFailed", internal.StatusLoginFailed}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newTestServer() + state := internal.CtxGetState(s.rootCtx) + state.Set(tt.status) + + resp, err := s.handleSleep(context.Background()) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) + }) + } +} + +func TestHandleSleep_ProceedsForActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Connecting", internal.StatusConnecting}, + {"Connected", internal.StatusConnected}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newTestServer() + state := internal.CtxGetState(s.rootCtx) + state.Set(tt.status) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.handleSleep(ctx) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, s.sleepTriggeredDown.Load()) + }) + } +} diff --git a/client/server/network.go b/client/server/network.go index 18b16795d..bb1cce56c 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -11,8 +11,8 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) type selectRoute struct { diff --git a/client/server/server.go b/client/server/server.go index 052809362..fbb3f0d52 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -46,6 +46,9 @@ const ( defaultMaxRetryTime = 14 * 24 * time.Hour defaultRetryMultiplier = 1.7 + // JWT token cache TTL for the client daemon (disabled by default) + defaultJWTCacheTTL = 0 + errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" @@ -81,6 +84,11 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + + // sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down + sleepTriggeredDown atomic.Bool + + jwtCache *jwtCache } type oauthAuthFlow struct { @@ -100,6 +108,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + jwtCache: newJWTCache(), } } @@ -183,7 +192,7 @@ func (s *Server) Start() error { s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -214,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -222,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil { log.Debugf("run client connection exited with error: %v", err) } log.Tracef("client connection exited") @@ -251,7 +260,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan) + doInitialAutoUpdate = false if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) return err @@ -353,6 +363,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.CustomDNSAddress = []byte{} } + config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + + if msg.DnsRouteInterval != nil { + interval := msg.DnsRouteInterval.AsDuration() + config.DNSRouteInterval = &interval + } + config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassPermissive = msg.RosenpassPermissive config.DisableAutoConnect = msg.DisableAutoConnect @@ -366,6 +383,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.DisableNotifications = msg.DisableNotifications config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound + config.EnableSSHRoot = msg.EnableSSHRoot + config.EnableSSHSFTP = msg.EnableSSHSFTP + config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding + config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding + if msg.DisableSSHAuth != nil { + config.DisableSSHAuth = msg.DisableSSHAuth + } + if msg.SshJWTCacheTTL != nil { + ttl := int(*msg.SshJWTCacheTTL) + config.SSHJWTCacheTTL = &ttl + } if msg.Mtu != nil { mtu := uint16(*msg.Mtu) @@ -476,13 +504,17 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro state.Set(internal.StatusConnecting) if msg.SetupKey == "" { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err } - if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) { + if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { log.Debugf("using previous oauth flow info") return &proto.LoginResponse{ @@ -499,7 +531,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } } - authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) + authInfo, err := oAuthFlow.RequestAuthInfo(ctx) if err != nil { log.Errorf("getting a request OAuth flow failed: %v", err) return nil, err @@ -697,7 +729,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + var doAutoUpdate bool + if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate { + doAutoUpdate = true + } + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) return s.waitForUp(callerCtx) } @@ -791,6 +828,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes defer s.mutex.Unlock() if err := s.cleanupConnection(); err != nil { + // todo review to update the status in case any type of error log.Errorf("failed to shut down properly: %v", err) return nil, err } @@ -883,6 +921,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe } if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { + // todo review to update the status in case any type of error log.Errorf("failed to cleanup connection: %v", err) return nil, err } @@ -1050,20 +1089,240 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() + + pbFullStatus.SshServerState = s.getSSHServerState() + statusResponse.FullStatus = pbFullStatus } return &statusResponse, nil } -func (s *Server) runProbes() { +// getSSHServerState retrieves the current SSH server state including enabled status and active sessions +func (s *Server) getSSHServerState() *proto.SSHServerState { + s.mutex.Lock() + connectClient := s.connectClient + s.mutex.Unlock() + + if connectClient == nil { + return nil + } + + engine := connectClient.Engine() + if engine == nil { + return nil + } + + enabled, sessions := engine.GetSSHServerStatus() + sshServerState := &proto.SSHServerState{ + Enabled: enabled, + } + + for _, session := range sessions { + sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{ + Username: session.Username, + RemoteAddress: session.RemoteAddress, + Command: session.Command, + JwtUsername: session.JWTUsername, + }) + } + + return sshServerState +} + +// GetPeerSSHHostKey retrieves SSH host key for a specific peer +func (s *Server) GetPeerSSHHostKey( + ctx context.Context, + req *proto.GetPeerSSHHostKeyRequest, +) (*proto.GetPeerSSHHostKeyResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + connectClient := s.connectClient + statusRecorder := s.statusRecorder + s.mutex.Unlock() + + if connectClient == nil { + return nil, errors.New("client not initialized") + } + + engine := connectClient.Engine() + if engine == nil { + return nil, errors.New("engine not started") + } + + peerAddress := req.GetPeerAddress() + hostKey, found := engine.GetPeerSSHKey(peerAddress) + + response := &proto.GetPeerSSHHostKeyResponse{ + Found: found, + } + + if !found { + return response, nil + } + + response.SshHostKey = hostKey + + if statusRecorder == nil { + return response, nil + } + + fullStatus := statusRecorder.GetFullStatus() + for _, peerState := range fullStatus.Peers { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + response.PeerIP = peerState.IP + response.PeerFQDN = peerState.FQDN + break + } + } + + return response, nil +} + +// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled) +func (s *Server) getJWTCacheTTL() time.Duration { + s.mutex.Lock() + config := s.config + s.mutex.Unlock() + + if config == nil || config.SSHJWTCacheTTL == nil { + return defaultJWTCacheTTL + } + + seconds := *config.SSHJWTCacheTTL + if seconds == 0 { + log.Debug("SSH JWT cache disabled (configured to 0)") + return 0 + } + + ttl := time.Duration(seconds) * time.Second + log.Debugf("SSH JWT cache TTL set to %v from config", ttl) + return ttl +} + +// RequestJWTAuth initiates JWT authentication flow for SSH +func (s *Server) RequestJWTAuth( + ctx context.Context, + msg *proto.RequestJWTAuthRequest, +) (*proto.RequestJWTAuthResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + config := s.config + s.mutex.Unlock() + + if config == nil { + return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured") + } + + jwtCacheTTL := s.getJWTCacheTTL() + if jwtCacheTTL > 0 { + if cachedToken, found := s.jwtCache.get(); found { + log.Debugf("JWT token found in cache, returning cached token for SSH authentication") + + return &proto.RequestJWTAuthResponse{ + CachedToken: cachedToken, + MaxTokenAge: int64(jwtCacheTTL.Seconds()), + }, nil + } + } + + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + + if hint == "" { + hint = profilemanager.GetLoginHint() + } + + isDesktop := isUnixRunningDesktop() + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err) + } + + authInfo, err := oAuthFlow.RequestAuthInfo(ctx) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err) + } + + s.mutex.Lock() + s.oauthAuthFlow.flow = oAuthFlow + s.oauthAuthFlow.info = authInfo + s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second) + s.mutex.Unlock() + + return &proto.RequestJWTAuthResponse{ + VerificationURI: authInfo.VerificationURI, + VerificationURIComplete: authInfo.VerificationURIComplete, + UserCode: authInfo.UserCode, + DeviceCode: authInfo.DeviceCode, + ExpiresIn: int64(authInfo.ExpiresIn), + MaxTokenAge: int64(jwtCacheTTL.Seconds()), + }, nil +} + +// WaitJWTToken waits for JWT authentication completion +func (s *Server) WaitJWTToken( + ctx context.Context, + req *proto.WaitJWTTokenRequest, +) (*proto.WaitJWTTokenResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + oAuthFlow := s.oauthAuthFlow.flow + authInfo := s.oauthAuthFlow.info + s.mutex.Unlock() + + if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode { + return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow") + } + + tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err) + } + + token := tokenInfo.GetTokenToUse() + + jwtCacheTTL := s.getJWTCacheTTL() + if jwtCacheTTL > 0 { + s.jwtCache.store(token, jwtCacheTTL) + log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL) + } else { + log.Debug("JWT caching disabled, not storing token") + } + + s.mutex.Lock() + s.oauthAuthFlow = oauthAuthFlow{} + s.mutex.Unlock() + return &proto.WaitJWTTokenResponse{ + Token: tokenInfo.GetTokenToUse(), + TokenType: tokenInfo.TokenType, + ExpiresIn: int64(tokenInfo.ExpiresIn), + }, nil +} + +func isUnixRunningDesktop() bool { + if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { + return false + } + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" +} + +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1074,7 +1333,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } @@ -1129,26 +1388,62 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p blockLANAccess := cfg.BlockLANAccess disableFirewall := cfg.DisableFirewall + enableSSHRoot := false + if cfg.EnableSSHRoot != nil { + enableSSHRoot = *cfg.EnableSSHRoot + } + + enableSSHSFTP := false + if cfg.EnableSSHSFTP != nil { + enableSSHSFTP = *cfg.EnableSSHSFTP + } + + enableSSHLocalPortForwarding := false + if cfg.EnableSSHLocalPortForwarding != nil { + enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding + } + + enableSSHRemotePortForwarding := false + if cfg.EnableSSHRemotePortForwarding != nil { + enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding + } + + disableSSHAuth := false + if cfg.DisableSSHAuth != nil { + disableSSHAuth = *cfg.DisableSSHAuth + } + + sshJWTCacheTTL := int32(0) + if cfg.SSHJWTCacheTTL != nil { + sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL) + } + return &proto.GetConfigResponse{ - ManagementUrl: managementURL.String(), - PreSharedKey: preSharedKey, - AdminURL: adminURL.String(), - InterfaceName: cfg.WgIface, - WireguardPort: int64(cfg.WgPort), - Mtu: int64(cfg.MTU), - DisableAutoConnect: cfg.DisableAutoConnect, - ServerSSHAllowed: *cfg.ServerSSHAllowed, - RosenpassEnabled: cfg.RosenpassEnabled, - RosenpassPermissive: cfg.RosenpassPermissive, - LazyConnectionEnabled: cfg.LazyConnectionEnabled, - BlockInbound: cfg.BlockInbound, - DisableNotifications: disableNotifications, - NetworkMonitor: networkMonitor, - DisableDns: disableDNS, - DisableClientRoutes: disableClientRoutes, - DisableServerRoutes: disableServerRoutes, - BlockLanAccess: blockLANAccess, - DisableFirewall: disableFirewall, + ManagementUrl: managementURL.String(), + PreSharedKey: preSharedKey, + AdminURL: adminURL.String(), + InterfaceName: cfg.WgIface, + WireguardPort: int64(cfg.WgPort), + Mtu: int64(cfg.MTU), + DisableAutoConnect: cfg.DisableAutoConnect, + ServerSSHAllowed: *cfg.ServerSSHAllowed, + RosenpassEnabled: cfg.RosenpassEnabled, + RosenpassPermissive: cfg.RosenpassPermissive, + LazyConnectionEnabled: cfg.LazyConnectionEnabled, + BlockInbound: cfg.BlockInbound, + DisableNotifications: disableNotifications, + NetworkMonitor: networkMonitor, + DisableDns: disableDNS, + DisableClientRoutes: disableClientRoutes, + DisableServerRoutes: disableServerRoutes, + BlockLanAccess: blockLANAccess, + EnableSSHRoot: enableSSHRoot, + EnableSSHSFTP: enableSSHSFTP, + EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding, + EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding, + DisableSSHAuth: disableSSHAuth, + SshJWTCacheTTL: sshJWTCacheTTL, + DisableFirewall: disableFirewall, }, nil } @@ -1252,9 +1547,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) if err := s.connectClient.Run(runningChan); err != nil { return err @@ -1379,6 +1674,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { RosenpassEnabled: peerState.RosenpassEnabled, Networks: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), + SshHostKey: peerState.SSHHostKey, } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) } diff --git a/client/server/server_test.go b/client/server/server_test.go index e0a4805f6..69b4453ea 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -15,9 +15,14 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -31,7 +36,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -108,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -290,7 +294,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -311,13 +314,19 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go new file mode 100644 index 000000000..8e360175d --- /dev/null +++ b/client/server/setconfig_test.go @@ -0,0 +1,314 @@ +package server + +import ( + "context" + "os/user" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config. +// This test uses reflection to detect when new fields are added but not handled in SetConfig. +func TestSetConfig_AllFieldsSaved(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.ConfigDirOverride = tempDir + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + + currUser, err := user.Current() + require.NoError(t, err) + + profName := "test-profile" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: "https://api.netbird.io:443", + } + _, err = profilemanager.UpdateOrCreateConfig(ic) + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + require.NoError(t, err) + + ctx := context.Background() + s := New(ctx, "console", "", false, false) + + rosenpassEnabled := true + rosenpassPermissive := true + serverSSHAllowed := true + interfaceName := "utun100" + wireguardPort := int64(51820) + preSharedKey := "test-psk" + disableAutoConnect := true + networkMonitor := true + disableClientRoutes := true + disableServerRoutes := true + disableDNS := true + disableFirewall := true + blockLANAccess := true + disableNotifications := true + lazyConnectionEnabled := true + blockInbound := true + mtu := int64(1280) + sshJWTCacheTTL := int32(300) + + req := &proto.SetConfigRequest{ + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + LazyConnectionEnabled: &lazyConnectionEnabled, + BlockInbound: &blockInbound, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + SshJWTCacheTTL: &sshJWTCacheTTL, + } + + _, err = s.SetConfig(ctx, req) + require.NoError(t, err) + + profState := profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + } + cfgPath, err := profState.FilePath() + require.NoError(t, err) + + cfg, err := profilemanager.GetConfig(cfgPath) + require.NoError(t, err) + + require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String()) + require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String()) + require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled) + require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive) + require.NotNil(t, cfg.ServerSSHAllowed) + require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed) + require.Equal(t, interfaceName, cfg.WgIface) + require.Equal(t, int(wireguardPort), cfg.WgPort) + require.Equal(t, preSharedKey, cfg.PreSharedKey) + require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect) + require.NotNil(t, cfg.NetworkMonitor) + require.Equal(t, networkMonitor, *cfg.NetworkMonitor) + require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes) + require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes) + require.Equal(t, disableDNS, cfg.DisableDNS) + require.Equal(t, disableFirewall, cfg.DisableFirewall) + require.Equal(t, blockLANAccess, cfg.BlockLANAccess) + require.NotNil(t, cfg.DisableNotifications) + require.Equal(t, disableNotifications, *cfg.DisableNotifications) + require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) + require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) + require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) + // IFaceBlackList contains defaults + extras + require.Contains(t, cfg.IFaceBlackList, "eth1") + require.Contains(t, cfg.IFaceBlackList, "eth2") + require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) + require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) + require.Equal(t, uint16(mtu), cfg.MTU) + require.NotNil(t, cfg.SSHJWTCacheTTL) + require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL) + + verifyAllFieldsCovered(t, req) +} + +// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest. +// If a new field is added to SetConfigRequest, this function will fail the test, +// forcing the developer to update both the SetConfig handler and this test. +func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { + t.Helper() + + metadataFields := map[string]bool{ + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + } + + expectedFields := map[string]bool{ + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + "EnableSSHRoot": true, + "EnableSSHSFTP": true, + "EnableSSHLocalPortForwarding": true, + "EnableSSHRemotePortForwarding": true, + "DisableSSHAuth": true, + "SshJWTCacheTTL": true, + } + + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unexpectedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + if metadataFields[fieldName] { + continue + } + + if !expectedFields[fieldName] { + unexpectedFields = append(unexpectedFields, fieldName) + } + } + + if len(unexpectedFields) > 0 { + t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields) + } +} + +// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest. +// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq. +func TestCLIFlags_MappedToSetConfig(t *testing.T) { + // Map of CLI flag names to their corresponding SetConfigRequest field names. + // This map must be updated when adding new config-related CLI flags. + flagToField := map[string]string{ + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + "enable-ssh-root": "EnableSSHRoot", + "enable-ssh-sftp": "EnableSSHSFTP", + "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding", + "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding", + "disable-ssh-auth": "DisableSSHAuth", + "ssh-jwt-cache-ttl": "SshJWTCacheTTL", + } + + // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). + fieldsWithoutCLIFlags := map[string]bool{ + "DisableNotifications": true, // Only settable via UI + } + + // Get all SetConfigRequest fields to verify our map is complete. + req := &proto.SetConfigRequest{} + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unmappedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip protobuf internal fields and metadata fields. + if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" { + continue + } + if fieldName == "Username" || fieldName == "ProfileName" { + continue + } + if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" { + continue + } + + // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag. + mappedToCLI := false + for _, mappedField := range flagToField { + if mappedField == fieldName { + mappedToCLI = true + break + } + } + + hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName] + + if !mappedToCLI && !hasNoCLIFlag { + unmappedFields = append(unmappedFields, fieldName) + } + } + + if len(unmappedFields) > 0 { + t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+ + "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+ + "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields) + } + + t.Log("All SetConfigRequest fields are properly documented") +} diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/server/state_generic.go b/client/server/state_generic.go index e6c7bdd44..980ba0cda 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -6,9 +6,11 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/ssh/config" ) func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&config.ShutdownState{}) } diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 087628907..019477d8e 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/ssh/config" ) func registerStates(mgr *statemanager.Manager) { @@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&nftables.ShutdownState{}) mgr.RegisterState(&iptables.ShutdownState{}) + mgr.RegisterState(&config.ShutdownState{}) } diff --git a/client/server/updateresult.go b/client/server/updateresult.go new file mode 100644 index 000000000..8e00d5062 --- /dev/null +++ b/client/server/updateresult.go @@ -0,0 +1,30 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/proto" +) + +func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) { + inst := installer.New() + dir := inst.TempDir() + + rh := installer.NewResultHandler(dir) + result, err := rh.Watch(ctx) + if err != nil { + log.Errorf("failed to watch update result: %v", err) + return &proto.InstallerResultResponse{ + Success: false, + ErrorMsg: err.Error(), + }, nil + } + + return &proto.InstallerResultResponse{ + Success: result.Success, + ErrorMsg: result.Error, + }, nil +} diff --git a/client/ssh/auth/auth.go b/client/ssh/auth/auth.go new file mode 100644 index 000000000..488b6e12e --- /dev/null +++ b/client/ssh/auth/auth.go @@ -0,0 +1,184 @@ +package auth + +import ( + "errors" + "fmt" + "sync" + + log "github.com/sirupsen/logrus" + + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" +) + +const ( + // DefaultUserIDClaim is the default JWT claim used to extract user IDs + DefaultUserIDClaim = "sub" + // Wildcard is a special user ID that matches all users + Wildcard = "*" +) + +var ( + ErrEmptyUserID = errors.New("JWT user ID is empty") + ErrUserNotAuthorized = errors.New("user is not authorized to access this peer") + ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user") + ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user") +) + +// Authorizer handles SSH fine-grained access control authorization +type Authorizer struct { + // UserIDClaim is the JWT claim to extract the user ID from + userIDClaim string + + // authorizedUsers is a list of hashed user IDs authorized to access this peer + authorizedUsers []sshuserhash.UserIDHash + + // machineUsers maps OS login usernames to lists of authorized user indexes + machineUsers map[string][]uint32 + + // mu protects the list of users + mu sync.RWMutex +} + +// Config contains configuration for the SSH authorizer +type Config struct { + // UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email") + UserIDClaim string + + // AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer + AuthorizedUsers []sshuserhash.UserIDHash + + // MachineUsers maps OS login usernames to indexes in AuthorizedUsers + // If a user wants to login as a specific OS user, their index must be in the corresponding list + MachineUsers map[string][]uint32 +} + +// NewAuthorizer creates a new SSH authorizer with empty configuration +func NewAuthorizer() *Authorizer { + a := &Authorizer{ + userIDClaim: DefaultUserIDClaim, + machineUsers: make(map[string][]uint32), + } + + return a +} + +// Update updates the authorizer configuration with new values +func (a *Authorizer) Update(config *Config) { + a.mu.Lock() + defer a.mu.Unlock() + + if config == nil { + // Clear authorization + a.userIDClaim = DefaultUserIDClaim + a.authorizedUsers = []sshuserhash.UserIDHash{} + a.machineUsers = make(map[string][]uint32) + log.Info("SSH authorization cleared") + return + } + + userIDClaim := config.UserIDClaim + if userIDClaim == "" { + userIDClaim = DefaultUserIDClaim + } + a.userIDClaim = userIDClaim + + // Store authorized users list + a.authorizedUsers = config.AuthorizedUsers + + // Store machine users mapping + machineUsers := make(map[string][]uint32) + for osUser, indexes := range config.MachineUsers { + if len(indexes) > 0 { + machineUsers[osUser] = indexes + } + } + a.machineUsers = machineUsers + + log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings", + len(config.AuthorizedUsers), len(machineUsers)) +} + +// Authorize validates if a user is authorized to login as the specified OS user +// Returns nil if authorized, or an error describing why authorization failed +func (a *Authorizer) Authorize(jwtUserID, osUsername string) error { + if jwtUserID == "" { + log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername) + return ErrEmptyUserID + } + + // Hash the JWT user ID for comparison + hashedUserID, err := sshuserhash.HashUserID(jwtUserID) + if err != nil { + log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err) + return fmt.Errorf("failed to hash user ID: %w", err) + } + + a.mu.RLock() + defer a.mu.RUnlock() + + // Find the index of this user in the authorized list + userIndex, found := a.findUserIndex(hashedUserID) + if !found { + log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername) + return ErrUserNotAuthorized + } + + return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex) +} + +// checkMachineUserMapping validates if a user's index is authorized for the specified OS user +// Checks wildcard mapping first, then specific OS user mappings +func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error { + // If wildcard exists and user's index is in the wildcard list, allow access to any OS user + if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard { + if a.isIndexInList(uint32(userIndex), wildcardIndexes) { + log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex) + return nil + } + } + + // Check for specific OS username mapping + allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername] + if !hasMachineUserMapping { + // No mapping for this OS user - deny by default (fail closed) + log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID) + return ErrNoMachineUserMapping + } + + // Check if user's index is in the allowed indexes for this specific OS user + if !a.isIndexInList(uint32(userIndex), allowedIndexes) { + log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex) + return ErrUserNotMappedToOSUser + } + + log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex) + return nil +} + +// GetUserIDClaim returns the JWT claim name used to extract user IDs +func (a *Authorizer) GetUserIDClaim() string { + a.mu.RLock() + defer a.mu.RUnlock() + return a.userIDClaim +} + +// findUserIndex finds the index of a hashed user ID in the authorized users list +// Returns the index and true if found, 0 and false if not found +func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) { + for i, id := range a.authorizedUsers { + if id == hashedUserID { + return i, true + } + } + return 0, false +} + +// isIndexInList checks if an index exists in a list of indexes +func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool { + for _, idx := range indexes { + if idx == index { + return true + } + } + return false +} diff --git a/client/ssh/auth/auth_test.go b/client/ssh/auth/auth_test.go new file mode 100644 index 000000000..2b3b5a414 --- /dev/null +++ b/client/ssh/auth/auth_test.go @@ -0,0 +1,612 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/sshauth" +) + +func TestAuthorizer_Authorize_UserNotInList(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list with one user + authorizedUserHash, err := sshauth.HashUserID("authorized-user") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Try to authorize a different user + err = authorizer.Authorize("unauthorized-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed) + } + authorizer.Update(config) + + // All attempts should fail when no machine user mappings exist (fail closed) + err = authorizer.Authorize("user1", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, // user1 and user2 can access root + "postgres": {1, 2}, // user2 and user3 can access postgres + "admin": {0}, // only user1 can access admin + }, + } + authorizer.Update(config) + + // user1 (index 0) should access root and admin + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user1", "admin") + assert.NoError(t, err) + + // user2 (index 1) should access root and postgres + err = authorizer.Authorize("user2", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + // user3 (index 2) should access postgres + err = authorizer.Authorize("user3", "postgres") + assert.NoError(t, err) +} + +func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, // user1 and user2 can access root + "postgres": {1, 2}, // user2 and user3 can access postgres + "admin": {0}, // only user1 can access admin + }, + } + authorizer.Update(config) + + // user1 (index 0) should NOT access postgres + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user2 (index 1) should NOT access admin + err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user3 (index 2) should NOT access root + err = authorizer.Authorize("user3", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user3 (index 2) should NOT access admin + err = authorizer.Authorize("user3", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) +} + +func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, // only root is mapped + }, + } + authorizer.Update(config) + + // user1 should NOT access an unmapped OS user (fail closed) + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Empty user ID should fail + err = authorizer.Authorize("", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrEmptyUserID) +} + +func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up multiple authorized users + userHashes := make([]sshauth.UserIDHash, 10) + for i := 0; i < 10; i++ { + hash, err := sshauth.HashUserID("user" + string(rune('0'+i))) + require.NoError(t, err) + userHashes[i] = hash + } + + // Create machine user mapping for all users + rootIndexes := make([]uint32, 10) + for i := 0; i < 10; i++ { + rootIndexes[i] = uint32(i) + } + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: userHashes, + MachineUsers: map[string][]uint32{ + "root": rootIndexes, + }, + } + authorizer.Update(config) + + // All users should be authorized for root + for i := 0; i < 10; i++ { + err := authorizer.Authorize("user"+string(rune('0'+i)), "root") + assert.NoError(t, err, "user%d should be authorized", i) + } + + // User not in list should fail + err := authorizer.Authorize("unknown-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up initial configuration + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{"root": {0}}, + } + authorizer.Update(config) + + // user1 should be authorized + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // Clear configuration + authorizer.Update(nil) + + // user1 should no longer be authorized + err = authorizer.Authorize("user1", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + // Machine users with empty index lists should be filtered out + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, + "postgres": {}, // empty list - should be filtered out + "admin": nil, // nil list - should be filtered out + }, + } + authorizer.Update(config) + + // root should work + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // postgres should fail (no mapping) + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + // admin should fail (no mapping) + err = authorizer.Authorize("user1", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_CustomUserIDClaim(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up with custom user ID claim + user1Hash, err := sshauth.HashUserID("user@example.com") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: "email", + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, + }, + } + authorizer.Update(config) + + // Verify the custom claim is set + assert.Equal(t, "email", authorizer.GetUserIDClaim()) + + // Authorize with email as user ID + err = authorizer.Authorize("user@example.com", "root") + assert.NoError(t, err) +} + +func TestAuthorizer_DefaultUserIDClaim(t *testing.T) { + authorizer := NewAuthorizer() + + // Verify default claim + assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim()) + assert.Equal(t, "sub", authorizer.GetUserIDClaim()) + + // Set up with empty user ID claim (should use default) + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: "", // empty - should use default + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Should fall back to default + assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim()) +} + +func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) { + authorizer := NewAuthorizer() + + // Create a large authorized users list + const numUsers = 1000 + userHashes := make([]sshauth.UserIDHash, numUsers) + for i := 0; i < numUsers; i++ { + hash, err := sshauth.HashUserID("user" + string(rune(i))) + require.NoError(t, err) + userHashes[i] = hash + } + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: userHashes, + MachineUsers: map[string][]uint32{ + "root": {0, 500, 999}, // first, middle, and last user + }, + } + authorizer.Update(config) + + // First user should have access + err := authorizer.Authorize("user"+string(rune(0)), "root") + assert.NoError(t, err) + + // Middle user should have access + err = authorizer.Authorize("user"+string(rune(500)), "root") + assert.NoError(t, err) + + // Last user should have access + err = authorizer.Authorize("user"+string(rune(999)), "root") + assert.NoError(t, err) + + // User not in mapping should NOT have access + err = authorizer.Authorize("user"+string(rune(100)), "root") + assert.Error(t, err) +} + +func TestAuthorizer_ConcurrentAuthorization(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, + }, + } + authorizer.Update(config) + + // Test concurrent authorization calls (should be safe to read concurrently) + const numGoroutines = 100 + errChan := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + user := "user1" + if idx%2 == 0 { + user = "user2" + } + err := authorizer.Authorize(user, "root") + errChan <- err + }(i) + } + + // Wait for all goroutines to complete and collect errors + for i := 0; i < numGoroutines; i++ { + err := <-errChan + assert.NoError(t, err) + } +} + +func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + // Configure with wildcard - all authorized users can access any OS user + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "*": {0, 1, 2}, // wildcard with all user indexes + }, + } + authorizer.Update(config) + + // All authorized users should be able to access any OS user + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + err = authorizer.Authorize("user3", "admin") + assert.NoError(t, err) + + err = authorizer.Authorize("user1", "ubuntu") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "nginx") + assert.NoError(t, err) + + err = authorizer.Authorize("user3", "docker") + assert.NoError(t, err) +} + +func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + // Configure with wildcard + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "*": {0}, + }, + } + authorizer.Update(config) + + // user1 should have access + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // Unauthorized user should still be denied even with wildcard + err = authorizer.Authorize("unauthorized-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure with both wildcard and specific mappings + // Wildcard takes precedence for users in the wildcard index list + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "*": {0, 1}, // wildcard for both users + "root": {0}, // specific mapping that would normally restrict to user1 only + }, + } + authorizer.Update(config) + + // Both users should be able to access root via wildcard (takes precedence over specific mapping) + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "root") + assert.NoError(t, err) + + // Both users should be able to access any other OS user via wildcard + err = authorizer.Authorize("user1", "postgres") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "admin") + assert.NoError(t, err) +} + +func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure WITHOUT wildcard - only specific mappings + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, // only user1 + "postgres": {1}, // only user2 + }, + } + authorizer.Update(config) + + // user1 can access root + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // user2 can access postgres + err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + // user1 cannot access postgres + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user2 cannot access root + err = authorizer.Authorize("user2", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // Neither can access unmapped OS users + err = authorizer.Authorize("user1", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) { + // This test covers the scenario where wildcard exists with limited indexes. + // Only users whose indexes are in the wildcard list can access any OS user via wildcard. + // Other users can only access OS users they are explicitly mapped to. + authorizer := NewAuthorizer() + + // Create two authorized user hashes (simulating the base64-encoded hashes in the config) + wasmHash, err := sshauth.HashUserID("wasm") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure with wildcard having only index 0, and specific mappings for other OS users + config := &Config{ + UserIDClaim: "sub", + AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash}, + MachineUsers: map[string][]uint32{ + "*": {0}, // wildcard with only index 0 - only wasm has wildcard access + "alice": {1}, // specific mapping for user2 + "bob": {1}, // specific mapping for user2 + }, + } + authorizer.Update(config) + + // wasm (index 0) should access any OS user via wildcard + err = authorizer.Authorize("wasm", "root") + assert.NoError(t, err, "wasm should access root via wildcard") + + err = authorizer.Authorize("wasm", "alice") + assert.NoError(t, err, "wasm should access alice via wildcard") + + err = authorizer.Authorize("wasm", "bob") + assert.NoError(t, err, "wasm should access bob via wildcard") + + err = authorizer.Authorize("wasm", "postgres") + assert.NoError(t, err, "wasm should access postgres via wildcard") + + // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres + err = authorizer.Authorize("user2", "alice") + assert.NoError(t, err, "user2 should access alice via explicit mapping") + + err = authorizer.Authorize("user2", "bob") + assert.NoError(t, err, "user2 should access bob via explicit mapping") + + err = authorizer.Authorize("user2", "root") + assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)") + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user2", "postgres") + assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)") + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + // Unauthorized user should still be denied + err = authorizer.Authorize("user3", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied") +} diff --git a/client/ssh/client.go b/client/ssh/client.go deleted file mode 100644 index afba347f8..000000000 --- a/client/ssh/client.go +++ /dev/null @@ -1,118 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "net" - "os" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/term" -) - -// Client wraps crypto/ssh Client to simplify usage -type Client struct { - client *ssh.Client -} - -// Close closes the wrapped SSH Client -func (c *Client) Close() error { - return c.client.Close() -} - -// OpenTerminal starts an interactive terminal session with the remote SSH server -func (c *Client) OpenTerminal() error { - session, err := c.client.NewSession() - if err != nil { - return fmt.Errorf("failed to open new session: %v", err) - } - defer func() { - err := session.Close() - if err != nil { - return - } - }() - - fd := int(os.Stdout.Fd()) - state, err := term.MakeRaw(fd) - if err != nil { - return fmt.Errorf("failed to run raw terminal: %s", err) - } - defer func() { - err := term.Restore(fd, state) - if err != nil { - return - } - }() - - w, h, err := term.GetSize(fd) - if err != nil { - return fmt.Errorf("terminal get size: %s", err) - } - - modes := ssh.TerminalModes{ - ssh.ECHO: 1, - ssh.TTY_OP_ISPEED: 14400, - ssh.TTY_OP_OSPEED: 14400, - } - - terminal := os.Getenv("TERM") - if terminal == "" { - terminal = "xterm-256color" - } - if err := session.RequestPty(terminal, h, w, modes); err != nil { - return fmt.Errorf("failed requesting pty session with xterm: %s", err) - } - - session.Stdout = os.Stdout - session.Stderr = os.Stderr - session.Stdin = os.Stdin - - if err := session.Shell(); err != nil { - return fmt.Errorf("failed to start login shell on the remote host: %s", err) - } - - if err := session.Wait(); err != nil { - if e, ok := err.(*ssh.ExitError); ok { - if e.ExitStatus() == 130 { - return nil - } - } - return fmt.Errorf("failed running SSH session: %s", err) - } - - return nil -} - -// DialWithKey connects to the remote SSH server with a provided private key file (PEM). -func DialWithKey(addr, user string, privateKey []byte) (*Client, error) { - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - return nil, err - } - - config := &ssh.ClientConfig{ - User: user, - Timeout: 5 * time.Second, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), - } - - return Dial("tcp", addr, config) -} - -// Dial connects to the remote SSH server. -func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) { - client, err := ssh.Dial(network, addr, config) - if err != nil { - return nil, err - } - return &Client{ - client: client, - }, nil -} diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go new file mode 100644 index 000000000..aab222093 --- /dev/null +++ b/client/ssh/client/client.go @@ -0,0 +1,710 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" + "golang.org/x/term" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util" +) + +const ( + // DefaultDaemonAddr is the default address for the NetBird daemon + DefaultDaemonAddr = "unix:///var/run/netbird.sock" + // DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows + DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731" +) + +// Client wraps crypto/ssh Client for simplified SSH operations +type Client struct { + client *ssh.Client + terminalState *term.State + terminalFd int + + windowsStdoutMode uint32 // nolint:unused + windowsStdinMode uint32 // nolint:unused +} + +func (c *Client) Close() error { + return c.client.Close() +} + +func (c *Client) OpenTerminal(ctx context.Context) error { + session, err := c.client.NewSession() + if err != nil { + return fmt.Errorf("new session: %w", err) + } + defer func() { + if err := session.Close(); err != nil { + log.Debugf("session close error: %v", err) + } + }() + + if err := c.setupTerminalMode(ctx, session); err != nil { + return err + } + + c.setupSessionIO(session) + + if err := session.Shell(); err != nil { + return fmt.Errorf("start shell: %w", err) + } + + return c.waitForSession(ctx, session) +} + +// setupSessionIO connects session streams to local terminal +func (c *Client) setupSessionIO(session *ssh.Session) { + session.Stdout = os.Stdout + session.Stderr = os.Stderr + session.Stdin = os.Stdin +} + +// waitForSession waits for the session to complete with context cancellation +func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error { + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + defer c.restoreTerminal() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return c.handleSessionError(err) + } +} + +// handleSessionError processes session termination errors +func (c *Client) handleSessionError(err error) error { + if err == nil { + return nil + } + + var e *ssh.ExitError + var em *ssh.ExitMissingError + if !errors.As(err, &e) && !errors.As(err, &em) { + return fmt.Errorf("session wait: %w", err) + } + + return nil +} + +// restoreTerminal restores the terminal to its original state +func (c *Client) restoreTerminal() { + if c.terminalState != nil { + _ = term.Restore(c.terminalFd, c.terminalState) + c.terminalState = nil + c.terminalFd = 0 + } + + if err := c.restoreWindowsConsoleState(); err != nil { + log.Debugf("restore Windows console state: %v", err) + } +} + +// ExecuteCommand executes a command on the remote host and returns the output +func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return nil, err + } + defer cleanup() + + output, err := session.CombinedOutput(command) + if err != nil { + var e *ssh.ExitError + var em *ssh.ExitMissingError + if !errors.As(err, &e) && !errors.As(err, &em) { + return output, fmt.Errorf("execute command: %w", err) + } + } + + return output, nil +} + +// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal +func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return fmt.Errorf("create session: %w", err) + } + defer cleanup() + + c.setupSessionIO(session) + + if err := session.Start(command); err != nil { + return fmt.Errorf("start command: %w", err) + } + + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + select { + case <-done: + return ctx.Err() + case <-time.After(100 * time.Millisecond): + return ctx.Err() + } + case err := <-done: + return c.handleCommandError(err) + } +} + +// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions +func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return fmt.Errorf("create session: %w", err) + } + defer cleanup() + + if err := c.setupTerminalMode(ctx, session); err != nil { + return fmt.Errorf("setup terminal mode: %w", err) + } + + c.setupSessionIO(session) + + if err := session.Start(command); err != nil { + return fmt.Errorf("start command: %w", err) + } + + defer c.restoreTerminal() + + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + select { + case <-done: + return ctx.Err() + case <-time.After(100 * time.Millisecond): + return ctx.Err() + } + case err := <-done: + return c.handleCommandError(err) + } +} + +// handleCommandError processes command execution errors +func (c *Client) handleCommandError(err error) error { + if err == nil { + return nil + } + + var e *ssh.ExitError + var em *ssh.ExitMissingError + if errors.As(err, &e) || errors.As(err, &em) { + return err + } + + return fmt.Errorf("execute command: %w", err) +} + +// setupContextCancellation sets up context cancellation for a session +func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + _ = session.Close() + case <-done: + } + }() + return func() { close(done) } +} + +// createSession creates a new SSH session with context cancellation setup +func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) { + session, err := c.client.NewSession() + if err != nil { + return nil, nil, fmt.Errorf("new session: %w", err) + } + + cancel := c.setupContextCancellation(ctx, session) + cleanup := func() { + cancel() + _ = session.Close() + } + + return session, cleanup, nil +} + +// getDefaultDaemonAddr returns the daemon address from environment or default for the OS +func getDefaultDaemonAddr() string { + if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" { + return addr + } + if runtime.GOOS == "windows" { + return DefaultDaemonAddrWindows + } + return DefaultDaemonAddr +} + +// DialOptions contains options for SSH connections +type DialOptions struct { + KnownHostsFile string + IdentityFile string + DaemonAddr string + SkipCachedToken bool + InsecureSkipVerify bool + NoBrowser bool +} + +// Dial connects to the given ssh server with specified options +func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) { + daemonAddr := opts.DaemonAddr + if daemonAddr == "" { + daemonAddr = getDefaultDaemonAddr() + } + opts.DaemonAddr = daemonAddr + + hostKeyCallback, err := createHostKeyCallback(opts) + if err != nil { + return nil, fmt.Errorf("create host key callback: %w", err) + } + + config := &ssh.ClientConfig{ + User: user, + Timeout: 30 * time.Second, + HostKeyCallback: hostKeyCallback, + } + + if opts.IdentityFile != "" { + authMethod, err := createSSHKeyAuth(opts.IdentityFile) + if err != nil { + return nil, fmt.Errorf("create SSH key auth: %w", err) + } + config.Auth = append(config.Auth, authMethod) + } + + return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser) +} + +// dialSSH establishes an SSH connection without JWT authentication +func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) { + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("connection close after handshake failure: %v", closeErr) + } + return nil, fmt.Errorf("ssh handshake: %w", err) + } + + client := ssh.NewClient(clientConn, chans, reqs) + return &Client{ + client: client, + }, nil +} + +// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection +func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("parse address %s: %w", addr, err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("parse port %s: %w", portStr, err) + } + + detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port) + if err != nil { + return nil, fmt.Errorf("SSH server detection: %w", err) + } + + if !serverType.RequiresJWT() { + return dialSSH(ctx, network, addr, config) + } + + jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser) + if err != nil { + return nil, fmt.Errorf("request JWT token: %w", err) + } + + configWithJWT := nbssh.AddJWTAuth(config, jwtToken) + return dialSSH(ctx, network, addr, configWithJWT) +} + +// requestJWTToken requests a JWT token from the NetBird daemon +func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) { + hint := profilemanager.GetLoginHint() + + conn, err := connectToDaemon(daemonAddr) + if err != nil { + return "", fmt.Errorf("connect to daemon: %w", err) + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener) +} + +// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon +func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + conn, err := connectToDaemon(daemonAddr) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("daemon connection close error: %v", err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + verifier := nbssh.NewDaemonHostKeyVerifier(client) + callback := nbssh.CreateHostKeyCallback(verifier) + return callback(hostname, remote, key) +} + +func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) { + addr := strings.TrimPrefix(daemonAddr, "tcp://") + + conn, err := grpc.NewClient( + addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err) + return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err) + } + + return conn, nil +} + +// getKnownHostsFiles returns paths to known_hosts files in order of preference +func getKnownHostsFiles() []string { + var files []string + + // User's known_hosts file (highest priority) + if homeDir, err := os.UserHomeDir(); err == nil { + userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts") + files = append(files, userKnownHosts) + } + + // NetBird managed known_hosts files + if runtime.GOOS == "windows" { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird") + files = append(files, netbirdKnownHosts) + } else { + files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird") + files = append(files, "/etc/ssh/ssh_known_hosts") + } + + return files +} + +// createHostKeyCallback creates a host key verification callback +func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) { + if opts.InsecureSkipVerify { + return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil { + return nil + } + return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile) + }, nil +} + +func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + if daemonAddr == "" { + return fmt.Errorf("no daemon address") + } + return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr) +} + +func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error { + knownHostsFiles := getKnownHostsFilesList(knownHostsFile) + hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles) + + for _, callback := range hostKeyCallbacks { + if err := callback(hostname, remote, key); err == nil { + return nil + } + } + return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname) +} + +func getKnownHostsFilesList(knownHostsFile string) []string { + if knownHostsFile != "" { + return []string{knownHostsFile} + } + return getKnownHostsFiles() +} + +func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback { + var hostKeyCallbacks []ssh.HostKeyCallback + for _, file := range knownHostsFiles { + if callback, err := knownhosts.New(file); err == nil { + hostKeyCallbacks = append(hostKeyCallbacks, callback) + } + } + return hostKeyCallbacks +} + +// createSSHKeyAuth creates SSH key authentication from a private key file +func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) { + keyData, err := os.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err) + } + + signer, err := ssh.ParsePrivateKey(keyData) + if err != nil { + return nil, fmt.Errorf("parse SSH private key: %w", err) + } + + return ssh.PublicKeys(signer), nil +} + +// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr +func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error { + localListener, err := net.Listen("tcp", localAddr) + if err != nil { + return fmt.Errorf("listen on %s: %w", localAddr, err) + } + + go func() { + defer func() { + if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Debugf("local listener close error: %v", err) + } + }() + for { + localConn, err := localListener.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + continue + } + + go c.handleLocalForward(localConn, remoteAddr) + } + }() + + <-ctx.Done() + if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Debugf("local listener close error: %v", err) + } + return ctx.Err() +} + +// handleLocalForward handles a single local port forwarding connection +func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { + defer func() { + if err := localConn.Close(); err != nil { + log.Debugf("local connection close error: %v", err) + } + }() + + channel, err := c.client.Dial("tcp", remoteAddr) + if err != nil { + if strings.Contains(err.Error(), "administratively prohibited") { + _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n") + } else { + log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err) + } + return + } + defer func() { + if err := channel.Close(); err != nil { + log.Debugf("remote channel close error: %v", err) + } + }() + + go func() { + if _, err := io.Copy(channel, localConn); err != nil { + log.Debugf("local forward copy error (local->remote): %v", err) + } + }() + + if _, err := io.Copy(localConn, channel); err != nil { + log.Debugf("local forward copy error (remote->local): %v", err) + } +} + +// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr +func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error { + host, port, err := c.parseRemoteAddress(remoteAddr) + if err != nil { + return fmt.Errorf("parse remote address: %w", err) + } + + req := c.buildTCPIPForwardRequest(host, port) + if err := c.sendTCPIPForwardRequest(req); err != nil { + return fmt.Errorf("setup remote forward: %w", err) + } + + go c.handleRemoteForwardChannels(ctx, localAddr) + + <-ctx.Done() + + if err := c.cancelTCPIPForwardRequest(req); err != nil { + return fmt.Errorf("cancel tcpip-forward: %w", err) + } + return ctx.Err() +} + +// parseRemoteAddress parses host and port from remote address string +func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) { + host, portStr, err := net.SplitHostPort(remoteAddr) + if err != nil { + return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err) + } + + return host, uint32(port), nil +} + +// buildTCPIPForwardRequest creates a tcpip-forward request message +func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg { + return tcpipForwardMsg{ + Host: host, + Port: port, + } +} + +// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding +func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error { + ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req)) + if err != nil { + return fmt.Errorf("send tcpip-forward request: %w", err) + } + if !ok { + return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)") + } + return nil +} + +// cancelTCPIPForwardRequest cancels the tcpip-forward request +func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error { + _, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req)) + if err != nil { + return fmt.Errorf("send cancel-tcpip-forward request: %w", err) + } + return nil +} + +// handleRemoteForwardChannels handles incoming forwarded-tcpip channels +func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) { + // Get the channel once - subsequent calls return nil! + channelRequests := c.client.HandleChannelOpen("forwarded-tcpip") + if channelRequests == nil { + log.Debugf("forwarded-tcpip channel type already being handled") + return + } + + for { + select { + case <-ctx.Done(): + return + case newChan := <-channelRequests: + if newChan != nil { + go c.handleRemoteForwardChannel(newChan, localAddr) + } + } + } +} + +// handleRemoteForwardChannel handles a single forwarded-tcpip channel +func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) { + channel, reqs, err := newChan.Accept() + if err != nil { + return + } + defer func() { + if err := channel.Close(); err != nil { + log.Debugf("remote channel close error: %v", err) + } + }() + + go ssh.DiscardRequests(reqs) + + localConn, err := net.Dial("tcp", localAddr) + if err != nil { + return + } + defer func() { + if err := localConn.Close(); err != nil { + log.Debugf("local connection close error: %v", err) + } + }() + + go func() { + if _, err := io.Copy(localConn, channel); err != nil { + log.Debugf("remote forward copy error (remote->local): %v", err) + } + }() + + if _, err := io.Copy(channel, localConn); err != nil { + log.Debugf("remote forward copy error (local->remote): %v", err) + } +} + +// tcpipForwardMsg represents the structure for tcpip-forward requests +type tcpipForwardMsg struct { + Host string + Port uint32 +} diff --git a/client/ssh/client/client_test.go b/client/ssh/client/client_test.go new file mode 100644 index 000000000..e38e02a86 --- /dev/null +++ b/client/ssh/client/client_test.go @@ -0,0 +1,512 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/ssh" + sshserver "github.com/netbirdio/netbird/client/ssh/server" + "github.com/netbirdio/netbird/client/ssh/testutil" +) + +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Guard against infinite recursion when test binary is called as "netbird ssh exec" + // This happens when running tests as non-privileged user with fallback + if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { + // Just exit with error to break the recursion + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") + os.Exit(1) + } + + // Run tests + code := m.Run() + + // Cleanup any created test users + testutil.CleanupTestUsers() + + os.Exit(code) +} + +func TestSSHClient_DialWithKey(t *testing.T) { + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create and start server + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Test Dial + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + // Verify client is connected + assert.NotNil(t, client.client) +} + +func TestSSHClient_CommandExecution(t *testing.T) { + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues") + } + + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + t.Run("ExecuteCommand captures output", func(t *testing.T) { + output, err := client.ExecuteCommand(ctx, "echo hello") + assert.NoError(t, err) + assert.Contains(t, string(output), "hello") + }) + + t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) { + err := client.ExecuteCommandWithIO(ctx, "echo world") + assert.NoError(t, err) + }) + + t.Run("commands with flags work", func(t *testing.T) { + output, err := client.ExecuteCommand(ctx, "echo -n test_flag") + assert.NoError(t, err) + assert.Equal(t, "test_flag", strings.TrimSpace(string(output))) + }) + + t.Run("non-zero exit codes don't return errors", func(t *testing.T) { + var testCmd string + if runtime.GOOS == "windows" { + testCmd = "echo hello | Select-String notfound" + } else { + testCmd = "echo 'hello' | grep 'notfound'" + } + _, err := client.ExecuteCommand(ctx, testCmd) + assert.NoError(t, err) + }) +} + +func TestSSHClient_ConnectionHandling(t *testing.T) { + server, serverAddr, _ := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Generate client key for multiple connections + + const numClients = 3 + clients := make([]*Client, numClients) + + currentUser := testutil.GetTestUsername(t) + for i := 0; i < numClients; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + cancel() + require.NoError(t, err, "Client %d should connect successfully", i) + clients[i] = client + } + + for i, client := range clients { + err := client.Close() + assert.NoError(t, err, "Client %d should close without error", i) + } +} + +func TestSSHClient_ContextCancellation(t *testing.T) { + server, serverAddr, _ := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + t.Run("connection with short timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + _, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + // Check for actual timeout-related errors rather than string matching + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "timeout"), + "Expected timeout-related error, got: %v", err) + } + }) + + t.Run("command execution cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("client close error: %v", err) + } + }() + + cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cmdCancel() + + err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10") + if err != nil { + var exitMissingErr *cryptossh.ExitMissingError + isValidCancellation := errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + errors.As(err, &exitMissingErr) + assert.True(t, isValidCancellation, "Should handle command cancellation properly") + } + }) +} + +func TestSSHClient_NoAuthMode(t *testing.T) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + + t.Run("any key succeeds in no-auth mode", func(t *testing.T) { + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + assert.NoError(t, err) + if client != nil { + require.NoError(t, client.Close(), "Client should close without error") + } + }) +} + +func TestSSHClient_TerminalState(t *testing.T) { + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + assert.Nil(t, client.terminalState) + assert.Equal(t, 0, client.terminalFd) + + client.restoreTerminal() + assert.Nil(t, client.terminalState) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := client.OpenTerminal(ctx) + // In test environment without a real terminal, this may complete quickly or timeout + // Both behaviors are acceptable for testing terminal state management + if err != nil { + if runtime.GOOS == "windows" { + assert.True(t, + strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "console"), + "Should timeout or have console error on Windows") + } else { + // On Unix systems in test environment, we may get various errors + // including timeouts or terminal-related errors + assert.True(t, + strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "terminal") || + strings.Contains(err.Error(), "pty"), + "Expected timeout or terminal-related error, got: %v", err) + } + } +} + +func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + + return server, serverAddr, client +} + +func TestSSHClient_PortForwarding(t *testing.T) { + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + t.Run("local forwarding times out gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080") + assert.Error(t, err) + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "connection"), + "Expected context or connection error") + }) + + t.Run("remote forwarding denied", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080") + assert.Error(t, err) + assert.True(t, + strings.Contains(err.Error(), "denied") || + strings.Contains(err.Error(), "disabled"), + "Should be denied by default") + }) + + t.Run("invalid addresses fail", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080") + assert.Error(t, err) + + err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address") + assert.Error(t, err) + }) +} + +func TestSSHClient_PortForwardingDataTransfer(t *testing.T) { + if testing.Short() { + t.Skip("Skipping data transfer test in short mode") + } + + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowLocalPortForwarding(true) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Port forwarding requires the actual current user, not test user + realUser, err := getRealCurrentUser() + require.NoError(t, err) + + // Skip if running as system account that can't do port forwarding + if testutil.IsSystemAccount(realUser) { + t.Skipf("Skipping port forwarding test - running as system account: %s", realUser) + } + + client, err := Dial(ctx, serverAddr, realUser, DialOptions{ + InsecureSkipVerify: true, // Skip host key verification for test + }) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("client close error: %v", err) + } + }() + + testServer, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + if err := testServer.Close(); err != nil { + t.Logf("test server close error: %v", err) + } + }() + + testServerAddr := testServer.Addr().String() + expectedResponse := "Hello, World!" + + go func() { + for { + conn, err := testServer.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { + if err := c.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + buf := make([]byte, 1024) + if _, err := c.Read(buf); err != nil { + t.Logf("connection read error: %v", err) + return + } + if _, err := c.Write([]byte(expectedResponse)); err != nil { + t.Logf("connection write error: %v", err) + } + }(conn) + } + }() + + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localAddr := localListener.Addr().String() + if err := localListener.Close(); err != nil { + t.Logf("local listener close error: %v", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go func() { + err := client.LocalPortForward(ctx, localAddr, testServerAddr) + if err != nil && !errors.Is(err, context.Canceled) { + if isWindowsPrivilegeError(err) { + t.Logf("Port forward failed due to Windows privilege restrictions: %v", err) + } else { + t.Logf("Port forward error: %v", err) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second) + require.NoError(t, err) + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + _, err = conn.Write([]byte("test")) + require.NoError(t, err) + + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Logf("set read deadline error: %v", err) + } + response := make([]byte, len(expectedResponse)) + n, err := io.ReadFull(conn, response) + require.NoError(t, err) + assert.Equal(t, len(expectedResponse), n) + assert.Equal(t, expectedResponse, string(response)) +} + +// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding +func getRealCurrentUser() (string, error) { + if runtime.GOOS == "windows" { + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + } + + if username := os.Getenv("USER"); username != "" { + return username, nil + } + + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + + return "", fmt.Errorf("unable to determine current user") +} + +// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions +func isWindowsPrivilegeError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD + strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess) + strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser) + strings.Contains(errStr, "privilege") || + strings.Contains(errStr, "access denied") || + strings.Contains(errStr, "user authentication failed") +} diff --git a/client/ssh/client/terminal_unix.go b/client/ssh/client/terminal_unix.go new file mode 100644 index 000000000..aaa3418f9 --- /dev/null +++ b/client/ssh/client/terminal_unix.go @@ -0,0 +1,127 @@ +//go:build !windows + +package client + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/term" +) + +func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error { + stdinFd := int(os.Stdin.Fd()) + + if !term.IsTerminal(stdinFd) { + return c.setupNonTerminalMode(ctx, session) + } + + fd := int(os.Stdin.Fd()) + + state, err := term.MakeRaw(fd) + if err != nil { + return c.setupNonTerminalMode(ctx, session) + } + + if err := c.setupTerminal(session, fd); err != nil { + if restoreErr := term.Restore(fd, state); restoreErr != nil { + log.Debugf("restore terminal state: %v", restoreErr) + } + return err + } + + c.terminalState = state + c.terminalFd = fd + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + go func() { + defer signal.Stop(sigChan) + select { + case <-ctx.Done(): + if err := term.Restore(fd, state); err != nil { + log.Debugf("restore terminal state: %v", err) + } + case sig := <-sigChan: + if err := term.Restore(fd, state); err != nil { + log.Debugf("restore terminal state: %v", err) + } + signal.Reset(sig) + s, ok := sig.(syscall.Signal) + if !ok { + log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig) + return + } + if err := syscall.Kill(syscall.Getpid(), s); err != nil { + log.Debugf("kill process with signal %v: %v", s, err) + } + } + }() + + return nil +} + +func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error { + return nil +} + +// restoreWindowsConsoleState is a no-op on Unix systems +func (c *Client) restoreWindowsConsoleState() error { + return nil +} + +func (c *Client) setupTerminal(session *ssh.Session, fd int) error { + w, h, err := term.GetSize(fd) + if err != nil { + return fmt.Errorf("get terminal size: %w", err) + } + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + // Ctrl+C + ssh.VINTR: 3, + // Ctrl+\ + ssh.VQUIT: 28, + // Backspace + ssh.VERASE: 127, + // Ctrl+U + ssh.VKILL: 21, + // Ctrl+D + ssh.VEOF: 4, + ssh.VEOL: 0, + ssh.VEOL2: 0, + // Ctrl+Q + ssh.VSTART: 17, + // Ctrl+S + ssh.VSTOP: 19, + // Ctrl+Z + ssh.VSUSP: 26, + // Ctrl+O + ssh.VDISCARD: 15, + // Ctrl+R + ssh.VREPRINT: 18, + // Ctrl+W + ssh.VWERASE: 23, + // Ctrl+V + ssh.VLNEXT: 22, + } + + terminal := os.Getenv("TERM") + if terminal == "" { + terminal = "xterm-256color" + } + + if err := session.RequestPty(terminal, h, w, modes); err != nil { + return fmt.Errorf("request pty: %w", err) + } + + return nil +} diff --git a/client/ssh/client/terminal_windows.go b/client/ssh/client/terminal_windows.go new file mode 100644 index 000000000..462438317 --- /dev/null +++ b/client/ssh/client/terminal_windows.go @@ -0,0 +1,265 @@ +package client + +import ( + "context" + "errors" + "fmt" + "os" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +const ( + enableProcessedInput = 0x0001 + enableLineInput = 0x0002 + enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT + enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode) + enableVirtualTerminalInput = 0x0200 +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +// ConsoleUnavailableError indicates that Windows console handles are not available +// (e.g., in CI environments where stdout/stdin are redirected) +type ConsoleUnavailableError struct { + Operation string + Err error +} + +func (e *ConsoleUnavailableError) Error() string { + return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err) +} + +func (e *ConsoleUnavailableError) Unwrap() error { + return e.Err +} + +type coord struct { + x, y int16 +} + +type smallRect struct { + left, top, right, bottom int16 +} + +type consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes uint16 + window smallRect + maximumWindowSize coord +} + +func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error { + if err := c.saveWindowsConsoleState(); err != nil { + var consoleErr *ConsoleUnavailableError + if errors.As(err, &consoleErr) { + log.Debugf("console unavailable, not requesting PTY: %v", err) + return nil + } + return fmt.Errorf("save console state: %w", err) + } + + if err := c.enableWindowsVirtualTerminal(); err != nil { + var consoleErr *ConsoleUnavailableError + if errors.As(err, &consoleErr) { + log.Debugf("virtual terminal unavailable: %v", err) + } else { + return fmt.Errorf("failed to enable virtual terminal: %w", err) + } + } + + w, h := c.getWindowsConsoleSize() + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.ICRNL: 1, + ssh.OPOST: 1, + ssh.ONLCR: 1, + ssh.ISIG: 1, + ssh.ICANON: 1, + ssh.VINTR: 3, // Ctrl+C + ssh.VQUIT: 28, // Ctrl+\ + ssh.VERASE: 127, // Backspace + ssh.VKILL: 21, // Ctrl+U + ssh.VEOF: 4, // Ctrl+D + ssh.VEOL: 0, + ssh.VEOL2: 0, + ssh.VSTART: 17, // Ctrl+Q + ssh.VSTOP: 19, // Ctrl+S + ssh.VSUSP: 26, // Ctrl+Z + ssh.VDISCARD: 15, // Ctrl+O + ssh.VWERASE: 23, // Ctrl+W + ssh.VLNEXT: 22, // Ctrl+V + ssh.VREPRINT: 18, // Ctrl+R + } + + if err := session.RequestPty("xterm-256color", h, w, modes); err != nil { + if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil { + log.Debugf("restore Windows console state: %v", restoreErr) + } + return fmt.Errorf("request pty: %w", err) + } + + return nil +} + +func (c *Client) saveWindowsConsoleState() error { + defer func() { + if r := recover(); r != nil { + log.Debugf("panic in saveWindowsConsoleState: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + + var stdoutMode, stdinMode uint32 + + ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode))) + if ret == 0 { + log.Debugf("failed to get stdout console mode: %v", err) + return &ConsoleUnavailableError{ + Operation: "get stdout console mode", + Err: err, + } + } + + ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode))) + if ret == 0 { + log.Debugf("failed to get stdin console mode: %v", err) + return &ConsoleUnavailableError{ + Operation: "get stdin console mode", + Err: err, + } + } + + c.terminalFd = 1 + c.windowsStdoutMode = stdoutMode + c.windowsStdinMode = stdinMode + + log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode) + return nil +} + +func (c *Client) enableWindowsVirtualTerminal() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + var mode uint32 + + ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode))) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "get stdout console mode for VT", + Err: winErr, + } + } + + mode |= enableVirtualTerminalProcessing + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode)) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "enable virtual terminal processing", + Err: winErr, + } + } + + ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode))) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "get stdin console mode for VT", + Err: winErr, + } + } + + mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput) + mode |= enableVirtualTerminalInput + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode)) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "set stdin raw mode", + Err: winErr, + } + } + + log.Debugf("enabled Windows virtual terminal processing") + return nil +} + +func (c *Client) getWindowsConsoleSize() (int, int) { + defer func() { + if r := recover(); r != nil { + log.Debugf("panic in getWindowsConsoleSize: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + var csbi consoleScreenBufferInfo + + ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi))) + if ret == 0 { + log.Debugf("failed to get console buffer info, using defaults: %v", err) + return 80, 24 + } + + width := int(csbi.window.right - csbi.window.left + 1) + height := int(csbi.window.bottom - csbi.window.top + 1) + + log.Debugf("Windows console size: %dx%d", width, height) + return width, height +} + +func (c *Client) restoreWindowsConsoleState() error { + var err error + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r) + } + }() + + if c.terminalFd != 1 { + return nil + } + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + + ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode)) + if ret == 0 { + log.Debugf("failed to restore stdout console mode: %v", winErr) + if err == nil { + err = fmt.Errorf("restore stdout console mode: %w", winErr) + } + } + + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode)) + if ret == 0 { + log.Debugf("failed to restore stdin console mode: %v", winErr) + if err == nil { + err = fmt.Errorf("restore stdin console mode: %w", winErr) + } + } + + c.terminalFd = 0 + c.windowsStdoutMode = 0 + c.windowsStdinMode = 0 + + log.Debugf("restored Windows console state") + return err +} diff --git a/client/ssh/common.go b/client/ssh/common.go new file mode 100644 index 000000000..6574437b5 --- /dev/null +++ b/client/ssh/common.go @@ -0,0 +1,195 @@ +package ssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/proto" +) + +const ( + NetBirdSSHConfigFile = "99-netbird.conf" + + UnixSSHConfigDir = "/etc/ssh/ssh_config.d" + WindowsSSHConfigDir = "ssh/ssh_config.d" +) + +var ( + // ErrPeerNotFound indicates the peer was not found in the network + ErrPeerNotFound = errors.New("peer not found in network") + // ErrNoStoredKey indicates the peer has no stored SSH host key + ErrNoStoredKey = errors.New("peer has no stored SSH host key") +) + +// HostKeyVerifier provides SSH host key verification +type HostKeyVerifier interface { + VerifySSHHostKey(peerAddress string, key []byte) error +} + +// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon +type DaemonHostKeyVerifier struct { + client proto.DaemonServiceClient +} + +// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier +func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier { + return &DaemonHostKeyVerifier{ + client: client, + } +} + +// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon +func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{ + PeerAddress: peerAddress, + }) + if err != nil { + return err + } + + if !response.GetFound() { + return ErrPeerNotFound + } + + storedKeyData := response.GetSshHostKey() + + return VerifyHostKey(storedKeyData, presentedKey, peerAddress) +} + +// printAuthInstructions prints authentication instructions to stderr +func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) { + _, _ = fmt.Fprintln(stderr, "SSH authentication required.") + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.") + _, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:") + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete) + + if authResponse.UserCode != "" { + _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + } + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") +} + +// RequestJWTToken requests or retrieves a JWT token for SSH authentication +func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) { + req := &proto.RequestJWTAuthRequest{} + if hint != "" { + req.Hint = &hint + } + authResponse, err := client.RequestJWTAuth(ctx, req) + if err != nil { + return "", fmt.Errorf("request JWT auth: %w", err) + } + + if useCache && authResponse.CachedToken != "" { + log.Debug("Using cached authentication token") + return authResponse.CachedToken, nil + } + + if stderr != nil { + printAuthInstructions(stderr, authResponse, openBrowser != nil) + } + + if openBrowser != nil { + if err := openBrowser(authResponse.VerificationURIComplete); err != nil { + log.Debugf("open browser: %v", err) + } + } + + tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ + DeviceCode: authResponse.DeviceCode, + UserCode: authResponse.UserCode, + }) + if err != nil { + return "", fmt.Errorf("wait for JWT token: %w", err) + } + + if stdout != nil { + _, _ = fmt.Fprintln(stdout, "Authentication successful!") + } + return tokenResponse.Token, nil +} + +// VerifyHostKey verifies an SSH host key against stored peer key data. +// Returns nil only if the presented key matches the stored key. +// Returns ErrNoStoredKey if storedKeyData is empty. +// Returns an error if the keys don't match or if parsing fails. +func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error { + if len(storedKeyData) == 0 { + return ErrNoStoredKey + } + + storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData) + if err != nil { + return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err) + } + + if !bytes.Equal(presentedKey, storedPubKey.Marshal()) { + return fmt.Errorf("SSH host key mismatch for %s", peerAddress) + } + + return nil +} + +// AddJWTAuth prepends JWT password authentication to existing auth methods. +// This ensures JWT auth is tried first while preserving any existing auth methods. +func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig { + configWithJWT := *config + configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...) + return &configWithJWT +} + +// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier. +// It tries multiple addresses (hostname, IP) for the peer before failing. +func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + addresses := buildAddressList(hostname, remote) + presentedKey := key.Marshal() + + for _, addr := range addresses { + if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil { + if errors.Is(err, ErrPeerNotFound) { + // Try other addresses for this peer + continue + } + return err + } + // Verified + return nil + } + + return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname) + } +} + +// buildAddressList creates a list of addresses to check for host key verification. +// It includes the original hostname and extracts the host part from the remote address if different. +func buildAddressList(hostname string, remote net.Addr) []string { + addresses := []string{hostname} + if host, _, err := net.SplitHostPort(remote.String()); err == nil { + if host != hostname { + addresses = append(addresses, host) + } + } + return addresses +} diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go new file mode 100644 index 000000000..cc47fd2d2 --- /dev/null +++ b/client/ssh/config/manager.go @@ -0,0 +1,277 @@ +package config + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" + + nbssh "github.com/netbirdio/netbird/client/ssh" +) + +const ( + EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG" + + EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG" + + MaxPeersForSSHConfig = 200 + + fileWriteTimeout = 2 * time.Second +) + +func isSSHConfigDisabled() bool { + value := os.Getenv(EnvDisableSSHConfig) + if value == "" { + return false + } + + disabled, err := strconv.ParseBool(value) + if err != nil { + return true + } + return disabled +} + +func isSSHConfigForced() bool { + value := os.Getenv(EnvForceSSHConfig) + if value == "" { + return false + } + + forced, err := strconv.ParseBool(value) + if err != nil { + return true + } + return forced +} + +// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count +func shouldGenerateSSHConfig(peerCount int) bool { + if isSSHConfigDisabled() { + return false + } + + if isSSHConfigForced() { + return true + } + + return peerCount <= MaxPeersForSSHConfig +} + +// writeFileWithTimeout writes data to a file with a timeout +func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error { + ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- os.WriteFile(filename, data, perm) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename) + } +} + +// Manager handles SSH client configuration for NetBird peers +type Manager struct { + sshConfigDir string + sshConfigFile string +} + +// PeerSSHInfo represents a peer's SSH configuration information +type PeerSSHInfo struct { + Hostname string + IP string + FQDN string +} + +// New creates a new SSH config manager +func New() *Manager { + sshConfigDir := getSystemSSHConfigDir() + return &Manager{ + sshConfigDir: sshConfigDir, + sshConfigFile: nbssh.NetBirdSSHConfigFile, + } +} + +// getSystemSSHConfigDir returns platform-specific SSH configuration directory +func getSystemSSHConfigDir() string { + if runtime.GOOS == "windows" { + return getWindowsSSHConfigDir() + } + return nbssh.UnixSSHConfigDir +} + +func getWindowsSSHConfigDir() string { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + return filepath.Join(programData, nbssh.WindowsSSHConfigDir) +} + +// SetupSSHClientConfig creates SSH client configuration for NetBird peers +func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error { + if !shouldGenerateSSHConfig(len(peers)) { + m.logSkipReason(len(peers)) + return nil + } + + sshConfig, err := m.buildSSHConfig(peers) + if err != nil { + return fmt.Errorf("build SSH config: %w", err) + } + return m.writeSSHConfig(sshConfig) +} + +func (m *Manager) logSkipReason(peerCount int) { + if isSSHConfigDisabled() { + log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig) + } else { + log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.", + peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig) + } +} + +func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) { + sshConfig := m.buildConfigHeader() + + var allHostPatterns []string + for _, peer := range peers { + hostPatterns := m.buildHostPatterns(peer) + allHostPatterns = append(allHostPatterns, hostPatterns...) + } + + if len(allHostPatterns) > 0 { + peerConfig, err := m.buildPeerConfig(allHostPatterns) + if err != nil { + return "", err + } + sshConfig += peerConfig + } + + return sshConfig, nil +} + +func (m *Manager) buildConfigHeader() string { + return "# NetBird SSH client configuration\n" + + "# Generated automatically - do not edit manually\n" + + "#\n" + + "# To disable SSH config management, use:\n" + + "# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" + + "#\n\n" +} + +func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { + uniquePatterns := make(map[string]bool) + var deduplicatedPatterns []string + for _, pattern := range allHostPatterns { + if !uniquePatterns[pattern] { + uniquePatterns[pattern] = true + deduplicatedPatterns = append(deduplicatedPatterns, pattern) + } + } + + execPath, err := m.getNetBirdExecutablePath() + if err != nil { + return "", fmt.Errorf("get NetBird executable path: %w", err) + } + + hostLine := strings.Join(deduplicatedPatterns, " ") + config := fmt.Sprintf("Host %s\n", hostLine) + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) + config += " StrictHostKeyChecking no\n" + + if runtime.GOOS == "windows" { + config += " UserKnownHostsFile NUL\n" + } else { + config += " UserKnownHostsFile /dev/null\n" + } + + config += " CheckHostIP no\n" + config += " LogLevel ERROR\n\n" + + return config, nil +} + +func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { + var hostPatterns []string + if peer.IP != "" { + hostPatterns = append(hostPatterns, peer.IP) + } + if peer.FQDN != "" { + hostPatterns = append(hostPatterns, peer.FQDN) + } + if peer.Hostname != "" && peer.Hostname != peer.FQDN { + hostPatterns = append(hostPatterns, peer.Hostname) + } + return hostPatterns +} + +func (m *Manager) writeSSHConfig(sshConfig string) error { + sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + + if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { + return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) + } + + if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil { + return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) + } + + log.Infof("Created NetBird SSH client config: %s", sshConfigPath) + return nil +} + +// RemoveSSHClientConfig removes NetBird SSH configuration +func (m *Manager) RemoveSSHClientConfig() error { + sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + err := os.Remove(sshConfigPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err) + } + if err == nil { + log.Infof("Removed NetBird SSH config: %s", sshConfigPath) + } + return nil +} + +func (m *Manager) getNetBirdExecutablePath() (string, error) { + execPath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("retrieve executable path: %w", err) + } + + realPath, err := filepath.EvalSymlinks(execPath) + if err != nil { + log.Debugf("symlink resolution failed: %v", err) + return execPath, nil + } + + return realPath, nil +} + +// GetSSHConfigDir returns the SSH config directory path +func (m *Manager) GetSSHConfigDir() string { + return m.sshConfigDir +} + +// GetSSHConfigFile returns the SSH config file name +func (m *Manager) GetSSHConfigFile() string { + return m.sshConfigFile +} diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go new file mode 100644 index 000000000..dc3ad95b3 --- /dev/null +++ b/client/ssh/config/manager_test.go @@ -0,0 +1,159 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestManager_SetupSSHClientConfig(t *testing.T) { + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Test SSH config generation with peers + peers := []PeerSSHInfo{ + { + Hostname: "peer1", + IP: "100.125.1.1", + FQDN: "peer1.nb.internal", + }, + { + Hostname: "peer2", + IP: "100.125.1.2", + FQDN: "peer2.nb.internal", + }, + } + + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Read generated config + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + content, err := os.ReadFile(configPath) + require.NoError(t, err) + + configStr := string(content) + + // Verify the basic SSH config structure exists + assert.Contains(t, configStr, "# NetBird SSH client configuration") + assert.Contains(t, configStr, "Generated automatically - do not edit manually") + + // Check that peer hostnames are included + assert.Contains(t, configStr, "100.125.1.1") + assert.Contains(t, configStr, "100.125.1.2") + assert.Contains(t, configStr, "peer1.nb.internal") + assert.Contains(t, configStr, "peer2.nb.internal") + + // Check platform-specific UserKnownHostsFile + if runtime.GOOS == "windows" { + assert.Contains(t, configStr, "UserKnownHostsFile NUL") + } else { + assert.Contains(t, configStr, "UserKnownHostsFile /dev/null") + } +} + +func TestGetSystemSSHConfigDir(t *testing.T) { + configDir := getSystemSSHConfigDir() + + // Path should not be empty + assert.NotEmpty(t, configDir) + + // Should be an absolute path + assert.True(t, filepath.IsAbs(configDir)) + + // On Unix systems, should start with /etc + // On Windows, should contain ProgramData + if runtime.GOOS == "windows" { + assert.Contains(t, strings.ToLower(configDir), "programdata") + } else { + assert.Contains(t, configDir, "/etc/ssh") + } +} + +func TestManager_PeerLimit(t *testing.T) { + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Generate many peers (more than limit) + var peers []PeerSSHInfo + for i := 0; i < MaxPeersForSSHConfig+10; i++ { + peers = append(peers, PeerSSHInfo{ + Hostname: fmt.Sprintf("peer%d", i), + IP: fmt.Sprintf("100.125.1.%d", i%254+1), + FQDN: fmt.Sprintf("peer%d.nb.internal", i), + }) + } + + // Test that SSH config generation is skipped when too many peers + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Config should not be created due to peer limit + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + _, err = os.Stat(configPath) + assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") +} + +func TestManager_ForcedSSHConfig(t *testing.T) { + // Set force environment variable + t.Setenv(EnvForceSSHConfig, "true") + + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Generate many peers (more than limit) + var peers []PeerSSHInfo + for i := 0; i < MaxPeersForSSHConfig+10; i++ { + peers = append(peers, PeerSSHInfo{ + Hostname: fmt.Sprintf("peer%d", i), + IP: fmt.Sprintf("100.125.1.%d", i%254+1), + FQDN: fmt.Sprintf("peer%d.nb.internal", i), + }) + } + + // Test that SSH config generation is forced despite many peers + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Config should be created despite peer limit due to force flag + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + _, err = os.Stat(configPath) + require.NoError(t, err, "SSH config should be created when forced") + + // Verify config contains peer hostnames + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + assert.Contains(t, configStr, "peer0.nb.internal") + assert.Contains(t, configStr, "peer1.nb.internal") +} diff --git a/client/ssh/config/shutdown_state.go b/client/ssh/config/shutdown_state.go new file mode 100644 index 000000000..22f0e0678 --- /dev/null +++ b/client/ssh/config/shutdown_state.go @@ -0,0 +1,22 @@ +package config + +// ShutdownState represents SSH configuration state that needs to be cleaned up. +type ShutdownState struct { + SSHConfigDir string + SSHConfigFile string +} + +// Name returns the state name for the state manager. +func (s *ShutdownState) Name() string { + return "ssh_config_state" +} + +// Cleanup removes SSH client configuration files. +func (s *ShutdownState) Cleanup() error { + manager := &Manager{ + sshConfigDir: s.SSHConfigDir, + sshConfigFile: s.SSHConfigFile, + } + + return manager.RemoveSSHClientConfig() +} diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go new file mode 100644 index 000000000..f23ea4c37 --- /dev/null +++ b/client/ssh/detection/detection.go @@ -0,0 +1,99 @@ +package detection + +import ( + "bufio" + "context" + "fmt" + "net" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // ServerIdentifier is the base response for NetBird SSH servers + ServerIdentifier = "NetBird-SSH-Server" + // ProxyIdentifier is the base response for NetBird SSH proxy + ProxyIdentifier = "NetBird-SSH-Proxy" + // JWTRequiredMarker is appended to responses when JWT is required + JWTRequiredMarker = "NetBird-JWT-Required" + + // DefaultTimeout is the default timeout for SSH server detection + DefaultTimeout = 5 * time.Second +) + +type ServerType string + +const ( + ServerTypeNetBirdJWT ServerType = "netbird-jwt" + ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt" + ServerTypeRegular ServerType = "regular" +) + +// Dialer provides network connection capabilities +type Dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// RequiresJWT checks if the server type requires JWT authentication +func (s ServerType) RequiresJWT() bool { + return s == ServerTypeNetBirdJWT +} + +// ExitCode returns the exit code for the detect command +func (s ServerType) ExitCode() int { + switch s { + case ServerTypeNetBirdJWT: + return 0 + case ServerTypeNetBirdNoJWT: + return 1 + case ServerTypeRegular: + return 2 + default: + return 2 + } +} + +// DetectSSHServerType detects SSH server type using the provided dialer +func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) { + targetAddr := net.JoinHostPort(host, strconv.Itoa(port)) + + conn, err := dialer.DialContext(ctx, "tcp", targetAddr) + if err != nil { + return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err) + } + defer conn.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetReadDeadline(deadline); err != nil { + return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err) + } + } + + reader := bufio.NewReader(conn) + serverBanner, err := reader.ReadString('\n') + if err != nil { + return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err) + } + + serverBanner = strings.TrimSpace(serverBanner) + log.Debugf("SSH server banner: %s", serverBanner) + + if !strings.HasPrefix(serverBanner, "SSH-") { + log.Debugf("Invalid SSH banner") + return ServerTypeRegular, nil + } + + if !strings.Contains(serverBanner, ServerIdentifier) { + log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier) + return ServerTypeRegular, nil + } + + if strings.Contains(serverBanner, JWTRequiredMarker) { + return ServerTypeNetBirdJWT, nil + } + + return ServerTypeNetBirdNoJWT, nil +} diff --git a/client/ssh/login.go b/client/ssh/login.go deleted file mode 100644 index cb2615e55..000000000 --- a/client/ssh/login.go +++ /dev/null @@ -1,53 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "net" - "net/netip" - "os" - "os/exec" - "runtime" - - "github.com/netbirdio/netbird/util" -) - -func isRoot() bool { - return os.Geteuid() == 0 -} - -func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) { - if !isRoot() { - shell := getUserShell(user) - if shell == "" { - shell = "/bin/sh" - } - - return shell, []string{"-l"}, nil - } - - loginPath, err = exec.LookPath("login") - if err != nil { - return "", nil, err - } - - addrPort, err := netip.ParseAddrPort(remoteAddr.String()) - if err != nil { - return "", nil, err - } - - switch runtime.GOOS { - case "linux": - if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") { - return loginPath, []string{"-f", user, "-p"}, nil - } - return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil - case "darwin": - return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil - case "freebsd": - return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil - default: - return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) - } -} diff --git a/client/ssh/lookup.go b/client/ssh/lookup.go deleted file mode 100644 index 9a7f6ff2e..000000000 --- a/client/ssh/lookup.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !darwin -// +build !darwin - -package ssh - -import "os/user" - -func userNameLookup(username string) (*user.User, error) { - if username == "" || (username == "root" && !isRoot()) { - return user.Current() - } - - return user.Lookup(username) -} diff --git a/client/ssh/lookup_darwin.go b/client/ssh/lookup_darwin.go deleted file mode 100644 index 913d049dc..000000000 --- a/client/ssh/lookup_darwin.go +++ /dev/null @@ -1,51 +0,0 @@ -//go:build darwin -// +build darwin - -package ssh - -import ( - "bytes" - "fmt" - "os/exec" - "os/user" - "strings" -) - -func userNameLookup(username string) (*user.User, error) { - if username == "" || (username == "root" && !isRoot()) { - return user.Current() - } - - var userObject *user.User - userObject, err := user.Lookup(username) - if err != nil && err.Error() == user.UnknownUserError(username).Error() { - return idUserNameLookup(username) - } else if err != nil { - return nil, err - } - - return userObject, nil -} - -func idUserNameLookup(username string) (*user.User, error) { - cmd := exec.Command("id", "-P", username) - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err) - } - colon := ":" - - if !bytes.Contains(out, []byte(username+colon)) { - return nil, fmt.Errorf("unable to find user in returned string") - } - // netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh - parts := strings.SplitN(string(out), colon, 10) - userObject := &user.User{ - Username: parts[0], - Uid: parts[2], - Gid: parts[3], - Name: parts[7], - HomeDir: parts[8], - } - return userObject, nil -} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go new file mode 100644 index 000000000..4e807e33c --- /dev/null +++ b/client/ssh/proxy/proxy.go @@ -0,0 +1,394 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/version" +) + +const ( + // sshConnectionTimeout is the timeout for SSH TCP connection establishment + sshConnectionTimeout = 120 * time.Second + // sshHandshakeTimeout is the timeout for SSH handshake completion + sshHandshakeTimeout = 30 * time.Second + + jwtAuthErrorMsg = "JWT authentication: %w" +) + +type SSHProxy struct { + daemonAddr string + targetHost string + targetPort int + stderr io.Writer + conn *grpc.ClientConn + daemonClient proto.DaemonServiceClient + browserOpener func(string) error +} + +func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { + grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") + grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, fmt.Errorf("connect to daemon: %w", err) + } + + return &SSHProxy{ + daemonAddr: daemonAddr, + targetHost: targetHost, + targetPort: targetPort, + stderr: stderr, + conn: grpcConn, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + browserOpener: browserOpener, + }, nil +} + +func (p *SSHProxy) Close() error { + if p.conn != nil { + return p.conn.Close() + } + return nil +} + +func (p *SSHProxy) Connect(ctx context.Context) error { + hint := profilemanager.GetLoginHint() + + jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener) + if err != nil { + return fmt.Errorf(jwtAuthErrorMsg, err) + } + + return p.runProxySSHServer(ctx, jwtToken) +} + +func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error { + serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion()) + + sshServer := &ssh.Server{ + Handler: func(s ssh.Session) { + p.handleSSHSession(ctx, s, jwtToken) + }, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": p.directTCPIPHandler, + }, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": func(s ssh.Session) { + p.sftpSubsystemHandler(s, jwtToken) + }, + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": p.tcpipForwardHandler, + "cancel-tcpip-forward": p.cancelTcpipForwardHandler, + }, + Version: serverVersion, + } + + hostKey, err := generateHostKey() + if err != nil { + return fmt.Errorf("generate host key: %w", err) + } + sshServer.HostSigners = []ssh.Signer{hostKey} + + conn := &stdioConn{ + stdin: os.Stdin, + stdout: os.Stdout, + } + + sshServer.HandleConn(conn) + + return nil +} + +func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) { + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + + sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err) + return + } + defer func() { _ = sshClient.Close() }() + + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) + return + } + defer func() { _ = serverSession.Close() }() + + serverSession.Stdin = session + serverSession.Stdout = session + serverSession.Stderr = session.Stderr() + + ptyReq, winCh, isPty := session.Pty() + if isPty { + if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil { + log.Debugf("PTY request to backend: %v", err) + } + + go func() { + for win := range winCh { + if err := serverSession.WindowChange(win.Height, win.Width); err != nil { + log.Debugf("window change: %v", err) + } + } + }() + } + + if len(session.Command()) > 0 { + if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { + log.Debugf("run command: %v", err) + p.handleProxyExitCode(session, err) + } + return + } + + if err = serverSession.Shell(); err != nil { + log.Debugf("start shell: %v", err) + return + } + if err := serverSession.Wait(); err != nil { + log.Debugf("session wait: %v", err) + p.handleProxyExitCode(session, err) + } +} + +func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { + var exitErr *cryptossh.ExitError + if errors.As(err, &exitErr) { + if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil { + log.Debugf("set exit status: %v", exitErr) + } + } +} + +func generateHostKey() (ssh.Signer, error) { + keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + return nil, fmt.Errorf("generate ED25519 key: %w", err) + } + + signer, err := cryptossh.ParsePrivateKey(keyPEM) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + return signer, nil +} + +type stdioConn struct { + stdin io.Reader + stdout io.Writer + closed bool + mu sync.Mutex +} + +func (c *stdioConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.EOF + } + c.mu.Unlock() + return c.stdin.Read(b) +} + +func (c *stdioConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.ErrClosedPipe + } + c.mu.Unlock() + return c.stdout.Write(b) +} + +func (c *stdioConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *stdioConn) LocalAddr() net.Addr { + return &net.UnixAddr{Name: "stdio", Net: "unix"} +} + +func (c *stdioConn) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: "stdio", Net: "unix"} +} + +func (c *stdioConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *stdioConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *stdioConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) { + _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy") +} + +func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { + ctx, cancel := context.WithCancel(s.Context()) + defer cancel() + + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + + sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken) + if err != nil { + _, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err) + _ = s.Exit(1) + return + } + defer func() { + if err := sshClient.Close(); err != nil { + log.Debugf("close SSH client: %v", err) + } + }() + + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(s, "create server session: %v\n", err) + _ = s.Exit(1) + return + } + defer func() { + if err := serverSession.Close(); err != nil { + log.Debugf("close server session: %v", err) + } + }() + + stdin, stdout, err := p.setupSFTPPipes(serverSession) + if err != nil { + log.Debugf("setup SFTP pipes: %v", err) + _ = s.Exit(1) + return + } + + if err := serverSession.RequestSubsystem("sftp"); err != nil { + _, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err) + _ = s.Exit(1) + return + } + + p.runSFTPBridge(ctx, s, stdin, stdout, serverSession) +} + +func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) { + stdin, err := serverSession.StdinPipe() + if err != nil { + return nil, nil, fmt.Errorf("get stdin pipe: %w", err) + } + + stdout, err := serverSession.StdoutPipe() + if err != nil { + return nil, nil, fmt.Errorf("get stdout pipe: %w", err) + } + + return stdin, stdout, nil +} + +func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) { + copyErrCh := make(chan error, 2) + + go func() { + _, err := io.Copy(stdin, s) + if err != nil { + log.Debugf("SFTP client to server copy: %v", err) + } + if err := stdin.Close(); err != nil { + log.Debugf("close stdin: %v", err) + } + copyErrCh <- err + }() + + go func() { + _, err := io.Copy(s, stdout) + if err != nil { + log.Debugf("SFTP server to client copy: %v", err) + } + copyErrCh <- err + }() + + go func() { + <-ctx.Done() + if err := serverSession.Close(); err != nil { + log.Debugf("force close server session on context cancellation: %v", err) + } + }() + + for i := 0; i < 2; i++ { + if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) { + log.Debugf("SFTP copy error: %v", err) + } + } + + if err := serverSession.Wait(); err != nil { + log.Debugf("SFTP session ended: %v", err) + } +} + +func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { + return false, []byte("port forwarding not supported in proxy") +} + +func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { + return true, nil +} + +func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { + config := &cryptossh.ClientConfig{ + User: user, + Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)}, + Timeout: sshHandshakeTimeout, + HostKeyCallback: p.verifyHostKey, + } + + dialer := &net.Dialer{ + Timeout: sshConnectionTimeout, + } + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("connect to server: %w", err) + } + + clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config) + if err != nil { + _ = conn.Close() + return nil, fmt.Errorf("SSH handshake: %w", err) + } + + return cryptossh.NewClient(clientConn, chans, reqs), nil +} + +func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error { + verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient) + callback := nbssh.CreateHostKeyCallback(verifier) + return callback(hostname, remote, key) +} diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go new file mode 100644 index 000000000..81d588801 --- /dev/null +++ b/client/ssh/proxy/proxy_test.go @@ -0,0 +1,384 @@ +package proxy + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "runtime" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" + "github.com/netbirdio/netbird/client/ssh/server" + "github.com/netbirdio/netbird/client/ssh/testutil" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" +) + +func TestMain(m *testing.M) { + if len(os.Args) > 2 && os.Args[1] == "ssh" { + if os.Args[2] == "exec" { + if len(os.Args) > 3 { + cmd := os.Args[3] + if cmd == "echo" && len(os.Args) > 4 { + fmt.Fprintln(os.Stdout, os.Args[4]) + os.Exit(0) + } + } + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args) + os.Exit(1) + } + } + + code := m.Run() + + testutil.CleanupTestUsers() + + os.Exit(code) +} + +func TestSSHProxy_verifyHostKey(t *testing.T) { + t.Run("calls daemon to verify host key", func(t *testing.T) { + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { _ = grpcConn.Close() }() + + proxy := &SSHProxy{ + daemonAddr: mockDaemon.addr, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + } + + testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + testPubKey, err := nbssh.GeneratePublicKey(testKey) + require.NoError(t, err) + + mockDaemon.setHostKey("test-host", testPubKey) + + err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey)) + assert.NoError(t, err) + }) + + t.Run("rejects unknown host key", func(t *testing.T) { + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { _ = grpcConn.Close() }() + + proxy := &SSHProxy{ + daemonAddr: mockDaemon.addr, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + } + + unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey) + require.NoError(t, err) + + err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "peer unknown-host not found in network") + }) +} + +func TestSSHProxy_Connect(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // TODO: Windows test times out - user switching and command execution tested on Linux + if runtime.GOOS == "windows" { + t.Skip("Skipping on Windows - covered by Linux tests") + } + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + hostPubKey, err := nbssh.GeneratePublicKey(hostKey) + require.NoError(t, err) + + serverConfig := &server.Config{ + HostKeyPEM: hostKey, + JWT: &server.JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + }, + } + sshServer := server.New(serverConfig) + sshServer.SetAllowRootLogin(true) + + // Configure SSH authorization for the test user + testUsername := testutil.GetTestUsername(t) + testJWTUser := "test-username" + testUserHash, err := sshuserhash.HashUserID(testJWTUser) + require.NoError(t, err) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + testUsername: {0}, // Index 0 in AuthorizedUsers + }, + } + sshServer.UpdateSSHAuth(authConfig) + + sshServerAddr := server.StartTestServer(t, sshServer) + defer func() { _ = sshServer.Stop() }() + + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + host, portStr, err := net.SplitHostPort(sshServerAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + mockDaemon.setHostKey(host, hostPubKey) + + validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser) + mockDaemon.setJWTToken(validToken) + + proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil) + require.NoError(t, err) + + clientConn, proxyConn := net.Pipe() + defer func() { _ = clientConn.Close() }() + + origStdin := os.Stdin + origStdout := os.Stdout + defer func() { + os.Stdin = origStdin + os.Stdout = origStdout + }() + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + stdoutReader, stdoutWriter, err := os.Pipe() + require.NoError(t, err) + + os.Stdin = stdinReader + os.Stdout = stdoutWriter + + go func() { + _, _ = io.Copy(stdinWriter, proxyConn) + }() + go func() { + _, _ = io.Copy(proxyConn, stdoutReader) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + connectErrCh := make(chan error, 1) + go func() { + connectErrCh <- proxyInstance.Connect(ctx) + }() + + sshConfig := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 3 * time.Second, + } + + sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig) + require.NoError(t, err, "Should connect to proxy server") + defer func() { _ = sshClientConn.Close() }() + + sshClient := cryptossh.NewClient(sshClientConn, chans, reqs) + + session, err := sshClient.NewSession() + require.NoError(t, err, "Should create session through full proxy to backend") + + outputCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + output, err := session.Output("echo hello-from-proxy") + outputCh <- output + errCh <- err + }() + + select { + case output := <-outputCh: + err := <-errCh + require.NoError(t, err, "Command should execute successfully through proxy") + assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy") + case <-time.After(3 * time.Second): + t.Fatal("Command execution timed out") + } + + _ = session.Close() + _ = sshClient.Close() + _ = clientConn.Close() + cancel() +} + +type mockDaemonServer struct { + proto.UnimplementedDaemonServiceServer + hostKeys map[string][]byte + jwtToken string +} + +func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) { + key, found := m.hostKeys[req.PeerAddress] + return &proto.GetPeerSSHHostKeyResponse{ + Found: found, + SshHostKey: key, + }, nil +} + +func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) { + return &proto.RequestJWTAuthResponse{ + CachedToken: m.jwtToken, + }, nil +} + +func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) { + return &proto.WaitJWTTokenResponse{ + Token: m.jwtToken, + }, nil +} + +type mockDaemon struct { + addr string + server *grpc.Server + impl *mockDaemonServer +} + +func startMockDaemon(t *testing.T) *mockDaemon { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + impl := &mockDaemonServer{ + hostKeys: make(map[string][]byte), + jwtToken: "test-jwt-token", + } + + grpcServer := grpc.NewServer() + proto.RegisterDaemonServiceServer(grpcServer, impl) + + go func() { + _ = grpcServer.Serve(listener) + }() + + return &mockDaemon{ + addr: listener.Addr().String(), + server: grpcServer, + impl: impl, + } +} + +func (m *mockDaemon) setHostKey(addr string, pubKey []byte) { + m.impl.hostKeys[addr] = pubKey +} + +func (m *mockDaemon) setJWTToken(token string) { + m.impl.jwtToken = token +} + +func (m *mockDaemon) stop() { + if m.server != nil { + m.server.Stop() + } +} + +func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey { + t.Helper() + pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes) + require.NoError(t, err) + return pubKey +} + +func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) { + t.Helper() + privateKey, jwksJSON := generateTestJWKS(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(jwksJSON); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + + return server, privateKey, server.URL +} + +func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &privateKey.PublicKey + n := publicKey.N.Bytes() + e := publicKey.E + + jwk := nbjwt.JSONWebKey{ + Kty: "RSA", + Kid: "test-key-id", + Use: "sig", + N: base64.RawURLEncoding.EncodeToString(n), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()), + } + + jwks := nbjwt.Jwks{ + Keys: []nbjwt.JSONWebKey{jwk}, + } + + jwksJSON, err := json.Marshal(jwks) + require.NoError(t, err) + + return privateKey, jwksJSON +} + +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string { + t.Helper() + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": user, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + return tokenString +} diff --git a/client/ssh/server.go b/client/ssh/server.go deleted file mode 100644 index 8c5db2547..000000000 --- a/client/ssh/server.go +++ /dev/null @@ -1,280 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "io" - "net" - "os" - "os/exec" - "os/user" - "runtime" - "strings" - "sync" - "time" - - "github.com/creack/pty" - "github.com/gliderlabs/ssh" - log "github.com/sirupsen/logrus" -) - -// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server -const DefaultSSHPort = 44338 - -// TerminalTimeout is the timeout for terminal session to be ready -const TerminalTimeout = 10 * time.Second - -// TerminalBackoffDelay is the delay between terminal session readiness checks -const TerminalBackoffDelay = 500 * time.Millisecond - -// DefaultSSHServer is a function that creates DefaultServer -func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { - return newDefaultServer(hostKeyPEM, addr) -} - -// Server is an interface of SSH server -type Server interface { - // Stop stops SSH server. - Stop() error - // Start starts SSH server. Blocking - Start() error - // RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys - RemoveAuthorizedKey(peer string) - // AddAuthorizedKey add a given peer key to server authorized keys - AddAuthorizedKey(peer, newKey string) error -} - -// DefaultServer is the embedded NetBird SSH server -type DefaultServer struct { - listener net.Listener - // authorizedKeys is ssh pub key indexed by peer WireGuard public key - authorizedKeys map[string]ssh.PublicKey - mu sync.Mutex - hostKeyPEM []byte - sessions []ssh.Session -} - -// newDefaultServer creates new server with provided host key -func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - allowedKeys := make(map[string]ssh.PublicKey) - return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil -} - -// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys -func (srv *DefaultServer) RemoveAuthorizedKey(peer string) { - srv.mu.Lock() - defer srv.mu.Unlock() - - delete(srv.authorizedKeys, peer) -} - -// AddAuthorizedKey add a given peer key to server authorized keys -func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error { - srv.mu.Lock() - defer srv.mu.Unlock() - - parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey)) - if err != nil { - return err - } - - srv.authorizedKeys[peer] = parsedKey - return nil -} - -// Stop stops SSH server. -func (srv *DefaultServer) Stop() error { - srv.mu.Lock() - defer srv.mu.Unlock() - err := srv.listener.Close() - if err != nil { - return err - } - for _, session := range srv.sessions { - err := session.Close() - if err != nil { - log.Warnf("failed closing SSH session from %v", err) - } - } - - return nil -} - -func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { - srv.mu.Lock() - defer srv.mu.Unlock() - - for _, allowed := range srv.authorizedKeys { - if ssh.KeysEqual(allowed, key) { - return true - } - } - - return false -} - -func prepareUserEnv(user *user.User, shell string) []string { - return []string{ - fmt.Sprint("SHELL=" + shell), - fmt.Sprint("USER=" + user.Username), - fmt.Sprint("HOME=" + user.HomeDir), - } -} - -func acceptEnv(s string) bool { - split := strings.Split(s, "=") - if len(split) != 2 { - return false - } - return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_") -} - -// sessionHandler handles SSH session post auth -func (srv *DefaultServer) sessionHandler(session ssh.Session) { - srv.mu.Lock() - srv.sessions = append(srv.sessions, session) - srv.mu.Unlock() - - defer func() { - err := session.Close() - if err != nil { - return - } - }() - - log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String()) - - localUser, err := userNameLookup(session.User()) - if err != nil { - _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint - err = session.Exit(1) - if err != nil { - return - } - log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User()) - return - } - - ptyReq, winCh, isPty := session.Pty() - if isPty { - loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr()) - if err != nil { - log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String()) - return - } - cmd := exec.Command(loginCmd, loginArgs...) - go func() { - <-session.Context().Done() - if cmd.Process == nil { - return - } - err := cmd.Process.Kill() - if err != nil { - log.Debugf("failed killing SSH process %v", err) - return - } - }() - cmd.Dir = localUser.HomeDir - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) - cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...) - for _, v := range session.Environ() { - if acceptEnv(v) { - cmd.Env = append(cmd.Env, v) - } - } - - log.Debugf("Login command: %s", cmd.String()) - file, err := pty.Start(cmd) - if err != nil { - log.Errorf("failed starting SSH server: %v", err) - } - - go func() { - for win := range winCh { - setWinSize(file, win.Width, win.Height) - } - }() - - srv.stdInOut(file, session) - - err = cmd.Wait() - if err != nil { - return - } - } else { - _, err := io.WriteString(session, "only PTY is supported.\n") - if err != nil { - return - } - err = session.Exit(1) - if err != nil { - return - } - } - log.Debugf("SSH session ended") -} - -func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { - go func() { - // stdin - _, err := io.Copy(file, session) - if err != nil { - _ = session.Exit(1) - return - } - }() - - // AWS Linux 2 machines need some time to open the terminal so we need to wait for it - timer := time.NewTimer(TerminalTimeout) - for { - select { - case <-timer.C: - _, _ = session.Write([]byte("Reached timeout while opening connection\n")) - _ = session.Exit(1) - return - default: - // stdout - writtenBytes, err := io.Copy(session, file) - if err != nil && writtenBytes != 0 { - _ = session.Exit(0) - return - } - time.Sleep(TerminalBackoffDelay) - } - } -} - -// Start starts SSH server. Blocking -func (srv *DefaultServer) Start() error { - log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String()) - - publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler) - hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM) - err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM) - if err != nil { - return err - } - - return nil -} - -func getUserShell(userID string) string { - if runtime.GOOS == "linux" { - output, _ := exec.Command("getent", "passwd", userID).Output() - line := strings.SplitN(string(output), ":", 10) - if len(line) > 6 { - return strings.TrimSpace(line[6]) - } - } - - shell := os.Getenv("SHELL") - if shell == "" { - shell = "/bin/sh" - } - return shell -} diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go new file mode 100644 index 000000000..7a01ce4f6 --- /dev/null +++ b/client/ssh/server/command_execution.go @@ -0,0 +1,206 @@ +package server + +import ( + "errors" + "fmt" + "io" + "os" + "os/exec" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// handleCommand executes an SSH command with privilege validation +func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { + hasPty := winCh != nil + + commandType := "command" + if hasPty { + commandType = "Pty command" + } + + logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) + + execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) + if err != nil { + logger.Errorf("%s creation failed: %v", commandType, err) + + errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType) + if hasPty { + errorMsg += " with Pty" + } + errorMsg += "\n" + + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return + } + + if !hasPty { + if s.executeCommand(logger, session, execCmd, cleanup) { + logger.Debugf("%s execution completed", commandType) + } + return + } + + defer cleanup() + + ptyReq, _, _ := session.Pty() + if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { + logger.Debugf("%s execution completed", commandType) + } +} + +func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { + localUser := privilegeResult.User + if localUser == nil { + return nil, nil, errors.New("no user in privilege result") + } + + // If PTY requested but su doesn't support --pty, skip su and use executor + // This ensures PTY functionality is provided (executor runs within our allocated PTY) + if hasPty && !s.suSupportsPty { + log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") + cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + if err != nil { + return nil, nil, fmt.Errorf("create command with privileges: %w", err) + } + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, cleanup, nil + } + + // Try su first for system integration (PAM/audit) when privileged + cmd, err := s.createSuCommand(session, localUser, hasPty) + if err != nil || privilegeResult.UsedFallback { + log.Debugf("su command failed, falling back to executor: %v", err) + cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + if err != nil { + return nil, nil, fmt.Errorf("create command with privileges: %w", err) + } + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, cleanup, nil + } + + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, func() {}, nil +} + +// executeCommand executes the command and handles I/O and exit codes +func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool { + defer cleanup() + + s.setupProcessGroup(execCmd) + + stdinPipe, err := execCmd.StdinPipe() + if err != nil { + logger.Errorf("create stdin pipe: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + execCmd.Stdout = session + execCmd.Stderr = session.Stderr() + + if execCmd.Dir != "" { + if _, err := os.Stat(execCmd.Dir); err != nil { + logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err) + execCmd.Dir = "/" + } + } + + if err := execCmd.Start(); err != nil { + logger.Errorf("command start failed: %v", err) + // no user message for exec failure, just exit + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + go s.handleCommandIO(logger, stdinPipe, session) + return s.waitForCommandCleanup(logger, session, execCmd) +} + +// handleCommandIO manages stdin/stdout copying in a goroutine +func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) { + defer func() { + if err := stdinPipe.Close(); err != nil { + logger.Debugf("stdin pipe close error: %v", err) + } + }() + if _, err := io.Copy(stdinPipe, session); err != nil { + logger.Debugf("stdin copy error: %v", err) + } +} + +// waitForCommandCleanup waits for command completion with session disconnect handling +func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool { + ctx := session.Context() + done := make(chan error, 1) + go func() { + done <- execCmd.Wait() + }() + + select { + case <-ctx.Done(): + logger.Debugf("session cancelled, terminating command") + s.killProcessGroup(execCmd) + + select { + case err := <-done: + logger.Tracef("command terminated after session cancellation: %v", err) + case <-time.After(5 * time.Second): + logger.Warnf("command did not terminate within 5 seconds after session cancellation") + } + + if err := session.Exit(130); err != nil { + logSessionExitError(logger, err) + } + return false + + case err := <-done: + return s.handleCommandCompletion(logger, session, err) + } +} + +// handleCommandCompletion handles command completion +func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool { + if err != nil { + logger.Debugf("command execution failed: %v", err) + s.handleSessionExit(session, err, logger) + return false + } + + s.handleSessionExit(session, nil, logger) + return true +} + +// handleSessionExit handles command errors and sets appropriate exit codes +func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) { + if err == nil { + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } + return + } + + var exitError *exec.ExitError + if errors.As(err, &exitError) { + if err := session.Exit(exitError.ExitCode()); err != nil { + logSessionExitError(logger, err) + } + } else { + logger.Debugf("non-exit error in command execution: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + } +} diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go new file mode 100644 index 000000000..01759a337 --- /dev/null +++ b/client/ssh/server/command_execution_js.go @@ -0,0 +1,57 @@ +//go:build js + +package server + +import ( + "context" + "errors" + "os/exec" + "os/user" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") + +// createSuCommand is not supported on JS/WASM +func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { + return nil, errNotSupported +} + +// createExecutorCommand is not supported on JS/WASM +func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { + return nil, nil, errNotSupported +} + +// prepareCommandEnv is not supported on JS/WASM +func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { + return nil +} + +// setupProcessGroup is not supported on JS/WASM +func (s *Server) setupProcessGroup(_ *exec.Cmd) { +} + +// killProcessGroup is not supported on JS/WASM +func (s *Server) killProcessGroup(*exec.Cmd) { +} + +// detectSuPtySupport always returns false on JS/WASM +func (s *Server) detectSuPtySupport(context.Context) bool { + return false +} + +// detectUtilLinuxLogin always returns false on JS/WASM +func (s *Server) detectUtilLinuxLogin(context.Context) bool { + return false +} + +// executeCommandWithPty is not supported on JS/WASM +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + logger.Errorf("PTY command execution not supported on JS/WASM") + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false +} diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go new file mode 100644 index 000000000..db1a9bcfe --- /dev/null +++ b/client/ssh/server/command_execution_unix.go @@ -0,0 +1,353 @@ +//go:build unix + +package server + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/exec" + "os/user" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/creack/pty" + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// ptyManager manages Pty file operations with thread safety +type ptyManager struct { + file *os.File + mu sync.RWMutex + closed bool + closeErr error + once sync.Once +} + +func newPtyManager(file *os.File) *ptyManager { + return &ptyManager{file: file} +} + +func (pm *ptyManager) Close() error { + pm.once.Do(func() { + pm.mu.Lock() + pm.closed = true + pm.closeErr = pm.file.Close() + pm.mu.Unlock() + }) + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.closeErr +} + +func (pm *ptyManager) Setsize(ws *pty.Winsize) error { + pm.mu.RLock() + defer pm.mu.RUnlock() + if pm.closed { + return errors.New("pty is closed") + } + return pty.Setsize(pm.file, ws) +} + +func (pm *ptyManager) File() *os.File { + return pm.file +} + +// detectSuPtySupport checks if su supports the --pty flag +func (s *Server) detectSuPtySupport(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + cmd := exec.CommandContext(ctx, "su", "--help") + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("su --help failed (may not support --help): %v", err) + return false + } + + supported := strings.Contains(string(output), "--pty") + log.Debugf("su --pty support detected: %v", supported) + return supported +} + +// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils). +// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent. +// See https://bugs.debian.org/1078023 for details. +func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { + if runtime.GOOS != "linux" { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + cmd := exec.CommandContext(ctx, "login", "--version") + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("login --version failed (likely shadow-utils): %v", err) + return false + } + + isUtilLinux := strings.Contains(string(output), "util-linux") + log.Debugf("util-linux login detected: %v", isUtilLinux) + return isUtilLinux +} + +// createSuCommand creates a command using su -l -c for privilege switching +func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { + suPath, err := exec.LookPath("su") + if err != nil { + return nil, fmt.Errorf("su command not available: %w", err) + } + + command := session.RawCommand() + if command == "" { + return nil, fmt.Errorf("no command specified for su execution") + } + + args := []string{"-l"} + if hasPty && s.suSupportsPty { + args = append(args, "--pty") + } + args = append(args, localUser.Username, "-c", command) + + cmd := exec.CommandContext(session.Context(), suPath, args...) + cmd.Dir = localUser.HomeDir + + return cmd, nil +} + +// getShellCommandArgs returns the shell command and arguments for executing a command string +func (s *Server) getShellCommandArgs(shell, cmdString string) []string { + if cmdString == "" { + return []string{shell, "-l"} + } + return []string{shell, "-l", "-c", cmdString} +} + +// prepareCommandEnv prepares environment variables for command execution on Unix +func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} + +// executeCommandWithPty executes a command with PTY allocation +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + termType := ptyReq.Term + if termType == "" { + termType = "xterm-256color" + } + execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType)) + + return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) +} + +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) + if err != nil { + logger.Errorf("Pty command creation failed: %v", err) + errorMsg := "User switching failed - login command not available\r\n" + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " ")) + return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) +} + +// runPtyCommand runs a command with PTY management (common code for interactive and command execution) +func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq) + if err != nil { + logger.Errorf("Pty start failed: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + ptyMgr := newPtyManager(ptmx) + defer func() { + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close error: %v", err) + } + }() + + go s.handlePtyWindowResize(logger, session, ptyMgr, winCh) + s.handlePtyIO(logger, session, ptyMgr) + s.waitForPtyCompletion(logger, session, execCmd, ptyMgr) + return true +} + +func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) { + winSize := &pty.Winsize{ + Cols: uint16(ptyReq.Window.Width), + Rows: uint16(ptyReq.Window.Height), + } + if winSize.Cols == 0 { + winSize.Cols = 80 + } + if winSize.Rows == 0 { + winSize.Rows = 24 + } + + ptmx, err := pty.StartWithSize(execCmd, winSize) + if err != nil { + return nil, fmt.Errorf("start Pty: %w", err) + } + + return ptmx, nil +} + +func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) { + for { + select { + case <-session.Context().Done(): + return + case win, ok := <-winCh: + if !ok { + return + } + if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil { + logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err) + } + } + } +} + +func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) { + ptmx := ptyMgr.File() + + go func() { + if _, err := io.Copy(ptmx, session); err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty input copy error: %v", err) + } + } + }() + + go func() { + defer func() { + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + logger.Debugf("session close error: %v", err) + } + }() + if _, err := io.Copy(session, ptmx); err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty output copy error: %v", err) + } + } + }() +} + +func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) { + ctx := session.Context() + done := make(chan error, 1) + go func() { + done <- execCmd.Wait() + }() + + select { + case <-ctx.Done(): + s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) + case err := <-done: + s.handlePtyCommandCompletion(logger, session, err) + } +} + +func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) { + logger.Debugf("Pty session cancelled, terminating command") + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close during session cancellation: %v", err) + } + + s.killProcessGroup(execCmd) + + select { + case err := <-done: + if err != nil { + logger.Debugf("Pty command terminated after session cancellation with error: %v", err) + } else { + logger.Debugf("Pty command terminated after session cancellation") + } + case <-time.After(5 * time.Second): + logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation") + } + + if err := session.Exit(130); err != nil { + logSessionExitError(logger, err) + } +} + +func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { + if err != nil { + logger.Debugf("Pty command execution failed: %v", err) + s.handleSessionExit(session, err, logger) + return + } + + // Normal completion + logger.Debugf("Pty command completed successfully") + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } +} + +func (s *Server) setupProcessGroup(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } +} + +func (s *Server) killProcessGroup(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + + logger := log.WithField("pid", cmd.Process.Pid) + pgid := cmd.Process.Pid + + if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil { + logger.Debugf("kill process group SIGTERM: %v", err) + return + } + + const gracePeriod = 500 * time.Millisecond + const checkInterval = 50 * time.Millisecond + + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + timeout := time.After(gracePeriod) + + for { + select { + case <-timeout: + if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil { + logger.Debugf("kill process group SIGKILL: %v", err) + } + return + case <-ticker.C: + if err := syscall.Kill(-pgid, 0); err != nil { + return + } + } + } +} diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go new file mode 100644 index 000000000..998796871 --- /dev/null +++ b/client/ssh/server/command_execution_windows.go @@ -0,0 +1,435 @@ +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "strings" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/ssh/server/winpty" +) + +// getUserEnvironment retrieves the Windows environment for the target user. +// Follows OpenSSH's resilient approach with graceful degradation on failures. +func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { + userToken, err := s.getUserToken(username, domain) + if err != nil { + return nil, fmt.Errorf("get user token: %w", err) + } + defer func() { + if err := windows.CloseHandle(userToken); err != nil { + log.Debugf("close user token: %v", err) + } + }() + + return s.getUserEnvironmentWithToken(userToken, username, domain) +} + +// getUserEnvironmentWithToken retrieves the Windows environment using an existing token. +func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { + userProfile, err := s.loadUserProfile(userToken, username, domain) + if err != nil { + log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) + userProfile = fmt.Sprintf("C:\\Users\\%s", username) + } + + envMap := make(map[string]string) + + if err := s.loadSystemEnvironment(envMap); err != nil { + log.Debugf("failed to load system environment from registry: %v", err) + } + + s.setUserEnvironmentVariables(envMap, userProfile, username, domain) + + var env []string + for key, value := range envMap { + env = append(env, key+"="+value) + } + + return env, nil +} + +// getUserToken creates a user token for the specified user. +func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { + privilegeDropper := NewPrivilegeDropper() + token, err := privilegeDropper.createToken(username, domain) + if err != nil { + return 0, fmt.Errorf("generate S4U user token: %w", err) + } + return token, nil +} + +// loadUserProfile loads the Windows user profile and returns the profile path. +func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) { + usernamePtr, err := windows.UTF16PtrFromString(username) + if err != nil { + return "", fmt.Errorf("convert username to UTF-16: %w", err) + } + + var domainUTF16 *uint16 + if domain != "" && domain != "." { + domainUTF16, err = windows.UTF16PtrFromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to UTF-16: %w", err) + } + } + + type profileInfo struct { + dwSize uint32 + dwFlags uint32 + lpUserName *uint16 + lpProfilePath *uint16 + lpDefaultPath *uint16 + lpServerName *uint16 + lpPolicyPath *uint16 + hProfile windows.Handle + } + + const PI_NOUI = 0x00000001 + + profile := profileInfo{ + dwSize: uint32(unsafe.Sizeof(profileInfo{})), + dwFlags: PI_NOUI, + lpUserName: usernamePtr, + lpServerName: domainUTF16, + } + + userenv := windows.NewLazySystemDLL("userenv.dll") + loadUserProfileW := userenv.NewProc("LoadUserProfileW") + + ret, _, err := loadUserProfileW.Call( + uintptr(userToken), + uintptr(unsafe.Pointer(&profile)), + ) + + if ret == 0 { + return "", fmt.Errorf("LoadUserProfileW: %w", err) + } + + if profile.lpProfilePath == nil { + return "", fmt.Errorf("LoadUserProfileW returned null profile path") + } + + profilePath := windows.UTF16PtrToString(profile.lpProfilePath) + return profilePath, nil +} + +// loadSystemEnvironment loads system-wide environment variables from registry. +func (s *Server) loadSystemEnvironment(envMap map[string]string) error { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, + `SYSTEM\CurrentControlSet\Control\Session Manager\Environment`, + registry.QUERY_VALUE) + if err != nil { + return fmt.Errorf("open system environment registry key: %w", err) + } + defer func() { + if err := key.Close(); err != nil { + log.Debugf("close registry key: %v", err) + } + }() + + return s.readRegistryEnvironment(key, envMap) +} + +// readRegistryEnvironment reads environment variables from a registry key. +func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error { + names, err := key.ReadValueNames(0) + if err != nil { + return fmt.Errorf("read registry value names: %w", err) + } + + for _, name := range names { + value, valueType, err := key.GetStringValue(name) + if err != nil { + log.Debugf("failed to read registry value %s: %v", name, err) + continue + } + + finalValue := s.expandRegistryValue(value, valueType, name) + s.setEnvironmentVariable(envMap, name, finalValue) + } + + return nil +} + +// expandRegistryValue expands registry values if they contain environment variables. +func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string { + if valueType != registry.EXPAND_SZ { + return value + } + + sourcePtr := windows.StringToUTF16Ptr(value) + expandedBuffer := make([]uint16, 1024) + expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s: %v", name, err) + return value + } + + // If buffer was too small, retry with larger buffer + if expandedLen > uint32(len(expandedBuffer)) { + expandedBuffer = make([]uint16, expandedLen) + expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s on retry: %v", name, err) + return value + } + } + + if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) { + return windows.UTF16ToString(expandedBuffer[:expandedLen-1]) + } + return value +} + +// setEnvironmentVariable sets an environment variable with special handling for PATH. +func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) { + upperName := strings.ToUpper(name) + + if upperName == "PATH" { + if existing, exists := envMap["PATH"]; exists && existing != value { + envMap["PATH"] = existing + ";" + value + } else { + envMap["PATH"] = value + } + } else { + envMap[upperName] = value + } +} + +// setUserEnvironmentVariables sets critical user-specific environment variables. +func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) { + envMap["USERPROFILE"] = userProfile + + if len(userProfile) >= 2 && userProfile[1] == ':' { + envMap["HOMEDRIVE"] = userProfile[:2] + envMap["HOMEPATH"] = userProfile[2:] + } + + envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming") + envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local") + + tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp") + envMap["TEMP"] = tempDir + envMap["TMP"] = tempDir + + envMap["USERNAME"] = username + if domain != "" && domain != "." { + envMap["USERDOMAIN"] = domain + envMap["USERDNSDOMAIN"] = domain + } + + systemVars := []string{ + "PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION", + "SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT", + "PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC", + } + + for _, sysVar := range systemVars { + if sysValue := os.Getenv(sysVar); sysValue != "" { + envMap[sysVar] = sysValue + } + } +} + +// prepareCommandEnv prepares environment variables for command execution on Windows +func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { + username, domain := s.parseUsername(localUser.Username) + userEnv, err := s.getUserEnvironment(username, domain) + if err != nil { + log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env + } + + env := userEnv + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} + +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + if privilegeResult.User == nil { + logger.Errorf("no user in privilege result") + return false + } + + cmd := session.Command() + shell := getUserShell(privilegeResult.User.Uid) + + if len(cmd) == 0 { + logger.Infof("starting interactive shell: %s", shell) + } else { + logger.Infof("executing command: %s", safeLogCommand(cmd)) + } + + s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) + return true +} + +// getShellCommandArgs returns the shell command and arguments for executing a command string +func (s *Server) getShellCommandArgs(shell, cmdString string) []string { + if cmdString == "" { + return []string{shell, "-NoLogo"} + } + return []string{shell, "-Command", cmdString} +} + +func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { + logger.Info("starting interactive shell") + s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand()) +} + +type PtyExecutionRequest struct { + Shell string + Command string + Width int + Height int + Username string + Domain string +} + +func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { + log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", + req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) + + privilegeDropper := NewPrivilegeDropper() + userToken, err := privilegeDropper.createToken(req.Username, req.Domain) + if err != nil { + return fmt.Errorf("create user token: %w", err) + } + defer func() { + if err := windows.CloseHandle(userToken); err != nil { + log.Debugf("close user token: %v", err) + } + }() + + server := &Server{} + userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) + if err != nil { + log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) + userEnv = os.Environ() + } + + workingDir := getUserHomeFromEnv(userEnv) + if workingDir == "" { + workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username) + } + + ptyConfig := winpty.PtyConfig{ + Shell: req.Shell, + Command: req.Command, + Width: req.Width, + Height: req.Height, + WorkingDir: workingDir, + } + + userConfig := winpty.UserConfig{ + Token: userToken, + Environment: userEnv, + } + + log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) + return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) +} + +func getUserHomeFromEnv(env []string) string { + for _, envVar := range env { + if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" { + return envVar[12:] + } + } + return "" +} + +func (s *Server) setupProcessGroup(_ *exec.Cmd) { + // Windows doesn't support process groups in the same way as Unix + // Process creation groups are handled differently +} + +func (s *Server) killProcessGroup(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + + logger := log.WithField("pid", cmd.Process.Pid) + + if err := cmd.Process.Kill(); err != nil { + logger.Debugf("kill process failed: %v", err) + } +} + +// detectSuPtySupport always returns false on Windows as su is not available +func (s *Server) detectSuPtySupport(context.Context) bool { + return false +} + +// detectUtilLinuxLogin always returns false on Windows +func (s *Server) detectUtilLinuxLogin(context.Context) bool { + return false +} + +// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + command := session.RawCommand() + if command == "" { + logger.Error("no command specified for PTY execution") + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command) +} + +// executeConPtyCommand executes a command using ConPty (common for interactive and command execution) +func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool { + localUser := privilegeResult.User + if localUser == nil { + logger.Errorf("no user in privilege result") + return false + } + + username, domain := s.parseUsername(localUser.Username) + shell := getUserShell(localUser.Uid) + + req := PtyExecutionRequest{ + Shell: shell, + Command: command, + Width: ptyReq.Window.Width, + Height: ptyReq.Window.Height, + Username: username, + Domain: domain, + } + + if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { + logger.Errorf("ConPty execution failed: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + logger.Debug("ConPty execution completed") + return true +} diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go new file mode 100644 index 000000000..34ffccfd2 --- /dev/null +++ b/client/ssh/server/compatibility_test.go @@ -0,0 +1,722 @@ +package server + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "fmt" + "io" + "net" + "os" + "os/exec" + "runtime" + "strings" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/testutil" +) + +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Guard against infinite recursion when test binary is called as "netbird ssh exec" + // This happens when running tests as non-privileged user with fallback + if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { + // Just exit with error to break the recursion + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") + os.Exit(1) + } + + // Run tests + code := m.Run() + + // Cleanup any created test users + testutil.CleanupTestUsers() + + os.Exit(code) +} + +// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client +func TestSSHServerCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH compatibility tests in short mode") + } + + // Check if ssh binary is available + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Set up SSH server - use our existing key generation for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate OpenSSH-compatible keys for client + clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Create temporary key files for SSH client + clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH) + defer cleanupKey() + + // Extract host and port from server address + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + // Get appropriate user for SSH connection (handle system accounts) + username := testutil.GetTestUsername(t) + + t.Run("basic command execution", func(t *testing.T) { + testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username) + }) + + t.Run("interactive command", func(t *testing.T) { + testSSHInteractiveCommand(t, host, portStr, clientKeyFile) + }) + + t.Run("port forwarding", func(t *testing.T) { + testSSHPortForwarding(t, host, portStr, clientKeyFile) + }) +} + +// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user. +func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) { + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "echo", "hello_world") + + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("SSH command failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully") +} + +// testSSHInteractiveCommand tests interactive shell session. +func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host)) + + stdin, err := cmd.StdinPipe() + if err != nil { + t.Skipf("Cannot create stdin pipe: %v", err) + return + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Skipf("Cannot create stdout pipe: %v", err) + return + } + + err = cmd.Start() + if err != nil { + t.Logf("Cannot start SSH session: %v", err) + return + } + + go func() { + defer func() { + if err := stdin.Close(); err != nil { + t.Logf("stdin close error: %v", err) + } + }() + time.Sleep(100 * time.Millisecond) + if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil { + t.Logf("stdin write error: %v", err) + } + time.Sleep(100 * time.Millisecond) + if _, err := stdin.Write([]byte("exit\n")); err != nil { + t.Logf("stdin write error: %v", err) + } + }() + + output, err := io.ReadAll(stdout) + if err != nil { + t.Logf("Cannot read SSH output: %v", err) + } + + err = cmd.Wait() + if err != nil { + t.Logf("SSH interactive session error: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work") +} + +// testSSHPortForwarding tests port forwarding compatibility. +func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + testServer, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer testServer.Close() + + testServerAddr := testServer.Addr().String() + expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK" + + go func() { + for { + conn, err := testServer.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { + if err := c.Close(); err != nil { + t.Logf("test server connection close error: %v", err) + } + }() + buf := make([]byte, 1024) + if _, err := c.Read(buf); err != nil { + t.Logf("Test server read error: %v", err) + } + if _, err := c.Write([]byte(expectedResponse)); err != nil { + t.Logf("Test server write error: %v", err) + } + }(conn) + } + }() + + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localAddr := localListener.Addr().String() + localListener.Close() + + _, localPort, err := net.SplitHostPort(localAddr) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr) + cmd := exec.CommandContext(ctx, "ssh", + "-i", keyFile, + "-p", port, + "-L", forwardSpec, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-N", + fmt.Sprintf("%s@%s", username, host)) + + err = cmd.Start() + if err != nil { + t.Logf("Cannot start SSH port forwarding: %v", err) + return + } + + defer func() { + if cmd.Process != nil { + if err := cmd.Process.Kill(); err != nil { + t.Logf("process kill error: %v", err) + } + } + if err := cmd.Wait(); err != nil { + t.Logf("process wait after kill: %v", err) + } + }() + + time.Sleep(500 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second) + if err != nil { + t.Logf("Cannot connect to forwarded port: %v", err) + return + } + defer func() { + if err := conn.Close(); err != nil { + t.Logf("forwarded connection close error: %v", err) + } + }() + + request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + _, err = conn.Write([]byte(request)) + require.NoError(t, err) + + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + log.Debugf("failed to set read deadline: %v", err) + } + response := make([]byte, len(expectedResponse)) + n, err := io.ReadFull(conn, response) + if err != nil { + t.Logf("Cannot read forwarded response: %v", err) + return + } + + assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes") + assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding") +} + +// isSSHClientAvailable checks if the ssh binary is available +func isSSHClientAvailable() bool { + _, err := exec.LookPath("ssh") + return err == nil +} + +// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use. +func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) { + // Check if ssh-keygen is available + if _, err := exec.LookPath("ssh-keygen"); err != nil { + // Fall back to our existing key generation and try to convert + return generateOpenSSHKeyFallback() + } + + // Create temporary file for ssh-keygen + tempFile, err := os.CreateTemp("", "ssh_keygen_*") + if err != nil { + return nil, nil, fmt.Errorf("create temp file: %w", err) + } + keyPath := tempFile.Name() + tempFile.Close() + + // Remove the temp file so ssh-keygen can create it + if err := os.Remove(keyPath); err != nil { + t.Logf("failed to remove key file: %v", err) + } + + // Clean up temp files + defer func() { + if err := os.Remove(keyPath); err != nil { + t.Logf("failed to cleanup key file: %v", err) + } + if err := os.Remove(keyPath + ".pub"); err != nil { + t.Logf("failed to cleanup public key file: %v", err) + } + }() + + // Generate key using ssh-keygen + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output)) + } + + // Read private key + privKeyBytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("read private key: %w", err) + } + + // Read public key + pubKeyBytes, err := os.ReadFile(keyPath + ".pub") + if err != nil { + return nil, nil, fmt.Errorf("read public key: %w", err) + } + + return privKeyBytes, pubKeyBytes, nil +} + +// generateOpenSSHKeyFallback falls back to generating keys using our existing method +func generateOpenSSHKeyFallback() ([]byte, []byte, error) { + // Generate shared.ED25519 key pair using our existing method + _, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("generate key: %w", err) + } + + // Convert to SSH format + sshPrivKey, err := ssh.NewSignerFromKey(privKey) + if err != nil { + return nil, nil, fmt.Errorf("create signer: %w", err) + } + + // For the fallback, just use our PKCS#8 format and hope it works + // This won't be in OpenSSH format but might still work with some SSH clients + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + return nil, nil, fmt.Errorf("generate fallback key: %w", err) + } + + // Get public key in SSH format + sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey()) + + return hostKey, sshPubKey, nil +} + +// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes +func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) { + t.Helper() + + tempFile, err := os.CreateTemp("", "ssh_test_key_*") + require.NoError(t, err) + + _, err = tempFile.Write(keyBytes) + require.NoError(t, err) + + err = tempFile.Close() + require.NoError(t, err) + + // Set proper permissions for SSH key (readable by owner only) + err = os.Chmod(tempFile.Name(), 0600) + require.NoError(t, err) + + cleanup := func() { + _ = os.Remove(tempFile.Name()) + } + + return tempFile.Name(), cleanup +} + +// createTempKeyFile creates a temporary SSH private key file (for backward compatibility) +func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) { + return createTempKeyFileFromBytes(t, privateKey) +} + +// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility +func TestSSHServerFeatureCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH feature compatibility tests in short mode") + } + + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Test various SSH features + testCases := []struct { + name string + testFunc func(t *testing.T, host, port, keyFile string) + description string + }{ + { + name: "command_with_flags", + testFunc: testCommandWithFlags, + description: "Commands with flags should work like standard SSH", + }, + { + name: "environment_variables", + testFunc: testEnvironmentVariables, + description: "Environment variables should be available", + }, + { + name: "exit_codes", + testFunc: testExitCodes, + description: "Exit codes should be properly handled", + }, + } + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.testFunc(t, host, portStr, clientKeyFile) + }) + } +} + +// testCommandWithFlags tests that commands with flags work properly +func testCommandWithFlags(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Test ls with flags + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "ls", "-la", "/tmp") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Command with flags failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + // Should not be empty and should not contain error messages + assert.NotEmpty(t, string(output), "ls -la should produce output") + assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed") +} + +// testEnvironmentVariables tests that environment is properly set up +func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "echo", "$HOME") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Environment test failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + // HOME environment variable should be available + homeOutput := strings.TrimSpace(string(output)) + assert.NotEmpty(t, homeOutput, "HOME environment variable should be set") + assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded") +} + +// testExitCodes tests that exit codes are properly handled +func testExitCodes(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Test successful command (exit code 0) + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "true") // always succeeds + + err := cmd.Run() + assert.NoError(t, err, "Command with exit code 0 should succeed") + + // Test failing command (exit code 1) + cmd = exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "false") // always fails + + err = cmd.Run() + assert.Error(t, err, "Command with exit code 1 should fail") + + // Check if it's the right kind of error + if exitError, ok := err.(*exec.ExitError); ok { + assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved") + } +} + +// TestSSHServerSecurityFeatures tests security-related SSH features +func TestSSHServerSecurityFeatures(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH security tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Set up SSH server with specific security settings + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + t.Run("key_authentication", func(t *testing.T) { + // Test that key authentication works + cmd := exec.Command("ssh", + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "PasswordAuthentication=no", + fmt.Sprintf("%s@%s", username, host), + "echo", "auth_success") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Key authentication failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "auth_success", "Key authentication should work") + }) + + t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) { + // Create a different key that shouldn't be accepted + wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey) + defer cleanupWrongKey() + + // Test that wrong key is rejected + cmd := exec.Command("ssh", + "-i", wrongKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "PasswordAuthentication=no", + fmt.Sprintf("%s@%s", username, host), + "echo", "should_not_work") + + err = cmd.Run() + assert.NoError(t, err, "Any key should work in no-auth mode") + }) +} + +// TestCrossPlatformCompatibility tests cross-platform behavior +func TestCrossPlatformCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping cross-platform compatibility tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + // Test platform-specific commands + var testCommand string + + switch runtime.GOOS { + case "windows": + testCommand = "echo %OS%" + default: + testCommand = "uname" + } + + cmd := exec.Command("ssh", + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + testCommand) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Platform-specific command failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + outputStr := strings.TrimSpace(string(output)) + t.Logf("Platform command output: %s", outputStr) + assert.NotEmpty(t, outputStr, "Platform-specific command should produce output") +} diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go new file mode 100644 index 000000000..8adc824ef --- /dev/null +++ b/client/ssh/server/executor_unix.go @@ -0,0 +1,253 @@ +//go:build unix + +package server + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// Exit codes for executor process communication +const ( + ExitCodeSuccess = 0 + ExitCodePrivilegeDropFail = 10 + ExitCodeShellExecFail = 11 + ExitCodeValidationFail = 12 +) + +// ExecutorConfig holds configuration for the executor process +type ExecutorConfig struct { + UID uint32 + GID uint32 + Groups []uint32 + WorkingDir string + Shell string + Command string + PTY bool +} + +// PrivilegeDropper handles secure privilege dropping in child processes +type PrivilegeDropper struct{} + +// NewPrivilegeDropper creates a new privilege dropper +func NewPrivilegeDropper() *PrivilegeDropper { + return &PrivilegeDropper{} +} + +// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping +func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) { + netbirdPath, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("get netbird executable path: %w", err) + } + + if err := pd.validatePrivileges(config.UID, config.GID); err != nil { + return nil, fmt.Errorf("invalid privileges: %w", err) + } + + args := []string{ + "ssh", "exec", + "--uid", fmt.Sprintf("%d", config.UID), + "--gid", fmt.Sprintf("%d", config.GID), + "--working-dir", config.WorkingDir, + "--shell", config.Shell, + } + + for _, group := range config.Groups { + args = append(args, "--groups", fmt.Sprintf("%d", group)) + } + + if config.PTY { + args = append(args, "--pty") + } + + if config.Command != "" { + args = append(args, "--cmd", config.Command) + } + + // Log executor args safely - show all args except hide the command value + safeArgs := make([]string, len(args)) + copy(safeArgs, args) + for i := 0; i < len(safeArgs)-1; i++ { + if safeArgs[i] == "--cmd" { + cmdParts := strings.Fields(safeArgs[i+1]) + safeArgs[i+1] = safeLogCommand(cmdParts) + break + } + } + log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) + return exec.CommandContext(ctx, netbirdPath, args...), nil +} + +// DropPrivileges performs privilege dropping with thread locking for security +func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error { + if err := pd.validatePrivileges(targetUID, targetGID); err != nil { + return fmt.Errorf("invalid privileges: %w", err) + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originalUID := os.Geteuid() + originalGID := os.Getegid() + + if originalUID != int(targetUID) || originalGID != int(targetGID) { + if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil { + return fmt.Errorf("set groups and IDs: %w", err) + } + } + + if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil { + return err + } + + log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID) + return nil +} + +// setGroupsAndIDs sets the supplementary groups, GID, and UID +func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error { + groups := make([]int, len(supplementaryGroups)) + for i, g := range supplementaryGroups { + groups[i] = int(g) + } + + if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" { + if len(groups) == 0 || groups[0] != int(targetGID) { + groups = append([]int{int(targetGID)}, groups...) + } + } + + if err := syscall.Setgroups(groups); err != nil { + return fmt.Errorf("setgroups to %v: %w", groups, err) + } + + if err := syscall.Setgid(int(targetGID)); err != nil { + return fmt.Errorf("setgid to %d: %w", targetGID, err) + } + + if err := syscall.Setuid(int(targetUID)); err != nil { + return fmt.Errorf("setuid to %d: %w", targetUID, err) + } + + return nil +} + +// validatePrivilegeDropSuccess validates that privilege dropping was successful +func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error { + if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil { + return err + } + + if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil { + return err + } + + return nil +} + +// validatePrivilegeDropReversibility ensures privileges cannot be restored +func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error { + if originalGID != int(targetGID) { + if err := syscall.Setegid(originalGID); err == nil { + return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID) + } + } + if originalUID != int(targetUID) { + if err := syscall.Seteuid(originalUID); err == nil { + return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID) + } + } + return nil +} + +// validateCurrentPrivileges validates the current UID and GID match the target +func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error { + currentUID := os.Geteuid() + if currentUID != int(targetUID) { + return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID) + } + + currentGID := os.Getegid() + if currentGID != int(targetGID) { + return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID) + } + + return nil +} + +// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures +func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) { + log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups) + + // TODO: Implement Pty support for executor path + if config.PTY { + config.PTY = false + } + + if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err) + os.Exit(ExitCodePrivilegeDropFail) + } + + if config.WorkingDir != "" { + if err := os.Chdir(config.WorkingDir); err != nil { + log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err) + } + } + + var execCmd *exec.Cmd + if config.Command == "" { + os.Exit(ExitCodeSuccess) + } + + execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) + execCmd.Stdin = os.Stdin + execCmd.Stdout = os.Stdout + execCmd.Stderr = os.Stderr + + cmdParts := strings.Fields(config.Command) + safeCmd := safeLogCommand(cmdParts) + log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + if err := execCmd.Run(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + // Normal command exit with non-zero code - not an SSH execution error + log.Tracef("command exited with code %d", exitError.ExitCode()) + os.Exit(exitError.ExitCode()) + } + + // Actual execution failure (command not found, permission denied, etc.) + log.Debugf("command execution failed: %v", err) + os.Exit(ExitCodeShellExecFail) + } + + os.Exit(ExitCodeSuccess) +} + +// validatePrivileges validates that privilege dropping to the target UID/GID is allowed +func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error { + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + // Allow same-user operations (no privilege dropping needed) + if uid == currentUID && gid == currentGID { + return nil + } + + // Only root can drop privileges to other users + if currentUID != 0 { + return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid) + } + + // Root can drop to any user (including root itself) + return nil +} diff --git a/client/ssh/server/executor_unix_test.go b/client/ssh/server/executor_unix_test.go new file mode 100644 index 000000000..0c5108f57 --- /dev/null +++ b/client/ssh/server/executor_unix_test.go @@ -0,0 +1,262 @@ +//go:build unix + +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { + pd := NewPrivilegeDropper() + + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + tests := []struct { + name string + uid uint32 + gid uint32 + wantErr bool + }{ + { + name: "same user - no privilege drop needed", + uid: currentUID, + gid: currentGID, + wantErr: false, + }, + { + name: "non-root to different user should fail", + uid: currentUID + 1, // Use a different UID to ensure it's actually different + gid: currentGID + 1, // Use a different GID to ensure it's actually different + wantErr: currentUID != 0, // Only fail if current user is not root + }, + { + name: "root can drop to any user", + uid: 1000, + gid: 1000, + wantErr: false, + }, + { + name: "root can stay as root", + uid: 0, + gid: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip non-root tests when running as root, and root tests when not root + if tt.name == "non-root to different user should fail" && currentUID == 0 { + t.Skip("Skipping non-root test when running as root") + } + if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 { + t.Skip("Skipping root test when not running as root") + } + + err := pd.validatePrivileges(tt.uid, tt.gid) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { + pd := NewPrivilegeDropper() + + config := ExecutorConfig{ + UID: 1000, + GID: 1000, + Groups: []uint32{1000, 1001}, + WorkingDir: "/home/testuser", + Shell: "/bin/bash", + Command: "ls -la", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, cmd) + + // Verify the command is calling netbird ssh exec + assert.Contains(t, cmd.Args, "ssh") + assert.Contains(t, cmd.Args, "exec") + assert.Contains(t, cmd.Args, "--uid") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "--gid") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "--groups") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "1001") + assert.Contains(t, cmd.Args, "--working-dir") + assert.Contains(t, cmd.Args, "/home/testuser") + assert.Contains(t, cmd.Args, "--shell") + assert.Contains(t, cmd.Args, "/bin/bash") + assert.Contains(t, cmd.Args, "--cmd") + assert.Contains(t, cmd.Args, "ls -la") +} + +func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) { + pd := NewPrivilegeDropper() + + config := ExecutorConfig{ + UID: 1000, + GID: 1000, + Groups: []uint32{1000}, + WorkingDir: "/home/testuser", + Shell: "/bin/bash", + Command: "", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, cmd) + + // Verify no command mode (command is empty so no --cmd flag) + assert.NotContains(t, cmd.Args, "--cmd") + assert.NotContains(t, cmd.Args, "--interactive") +} + +// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping +// This test requires root privileges and will be skipped if not running as root +func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("This test requires root privileges") + } + + // Find a non-root user to test with + testUser, err := findNonRootUser() + if err != nil { + t.Skip("No suitable non-root user found for testing") + } + + // Verify the user actually exists by looking it up again + _, err = user.LookupId(testUser.Uid) + if err != nil { + t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err) + } + + uid64, err := strconv.ParseUint(testUser.Uid, 10, 32) + require.NoError(t, err) + targetUID := uint32(uid64) + + gid64, err := strconv.ParseUint(testUser.Gid, 10, 32) + require.NoError(t, err) + targetGID := uint32(gid64) + + // Test in a child process to avoid affecting the test runner + if os.Getenv("TEST_PRIVILEGE_DROP") == "1" { + pd := NewPrivilegeDropper() + + // This should succeed + err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID}) + require.NoError(t, err) + + // Verify we are now running as the target user + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + assert.Equal(t, targetUID, currentUID, "UID should match target") + assert.Equal(t, targetGID, currentGID, "GID should match target") + assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root") + assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group") + + return + } + + // Fork a child process to test privilege dropping + cmd := os.Args[0] + args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"} + + env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1") + + execCmd := exec.Command(cmd, args...) + execCmd.Env = env + + err = execCmd.Run() + require.NoError(t, err, "Child process should succeed") +} + +// findNonRootUser finds any non-root user on the system for testing +func findNonRootUser() (*user.User, error) { + // Try common non-root users, but avoid "nobody" on macOS due to negative UID issues + commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"} + + for _, username := range commonUsers { + if u, err := user.Lookup(username); err == nil { + // Parse as signed integer first to handle negative UIDs + uid64, err := strconv.ParseInt(u.Uid, 10, 32) + if err != nil { + continue + } + // Skip negative UIDs (like nobody=-2 on macOS) and root + if uid64 > 0 && uid64 != 0 { + return u, nil + } + } + } + + // If no common users found, try to find any regular user with UID > 100 + // This helps on macOS where regular users start at UID 501 + allUsers := []string{"vma", "user", "test", "admin"} + for _, username := range allUsers { + if u, err := user.Lookup(username); err == nil { + uid64, err := strconv.ParseInt(u.Uid, 10, 32) + if err != nil { + continue + } + if uid64 > 100 { // Regular user + return u, nil + } + } + } + + // If no common users found, return an error + return nil, fmt.Errorf("no suitable non-root user found on this system") +} + +func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) { + pd := NewPrivilegeDropper() + currentUID := uint32(os.Geteuid()) + + if currentUID == 0 { + // When running as root, test that root can create commands for any user + config := ExecutorConfig{ + UID: 1000, // Target non-root user + GID: 1000, + Groups: []uint32{1000}, + WorkingDir: "/tmp", + Shell: "/bin/sh", + Command: "echo test", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + assert.NoError(t, err, "Root should be able to create commands for any user") + assert.NotNil(t, cmd) + } else { + // When running as non-root, test that we can't drop to a different user + config := ExecutorConfig{ + UID: 0, // Try to target root + GID: 0, + Groups: []uint32{0}, + WorkingDir: "/tmp", + Shell: "/bin/sh", + Command: "echo test", + } + + _, err := pd.CreateExecutorCommand(context.Background(), config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot drop privileges") + } +} diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go new file mode 100644 index 000000000..d3504e056 --- /dev/null +++ b/client/ssh/server/executor_windows.go @@ -0,0 +1,570 @@ +//go:build windows + +package server + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "strings" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + ExitCodeSuccess = 0 + ExitCodeLogonFail = 10 + ExitCodeCreateProcessFail = 11 + ExitCodeWorkingDirFail = 12 + ExitCodeShellExecFail = 13 + ExitCodeValidationFail = 14 +) + +type WindowsExecutorConfig struct { + Username string + Domain string + WorkingDir string + Shell string + Command string + Args []string + Interactive bool + Pty bool + PtyWidth int + PtyHeight int +} + +type PrivilegeDropper struct{} + +func NewPrivilegeDropper() *PrivilegeDropper { + return &PrivilegeDropper{} +} + +var ( + advapi32 = windows.NewLazyDLL("advapi32.dll") + procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId") +) + +const ( + logon32LogonNetwork = 3 // Network logon - no password required for authenticated users + + // Common error messages + commandFlag = "-Command" + closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials + convertUsernameError = "convert username to UTF16: %w" + convertDomainError = "convert domain to UTF16: %w" +) + +// CreateWindowsExecutorCommand creates a Windows command with privilege dropping. +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) { + if config.Username == "" { + return nil, 0, errors.New("username cannot be empty") + } + if config.Shell == "" { + return nil, 0, errors.New("shell cannot be empty") + } + + shell := config.Shell + + var shellArgs []string + if config.Command != "" { + shellArgs = []string{shell, commandFlag, config.Command} + } else { + shellArgs = []string{shell} + } + + log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) + + cmd, token, err := pd.CreateWindowsProcessAsUser( + ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) + if err != nil { + return nil, 0, fmt.Errorf("create Windows process as user: %w", err) + } + + return cmd, token, nil +} + +const ( + // StatusSuccess represents successful LSA operation + StatusSuccess = 0 + + // KerbS4ULogonType message type for domain users with Kerberos + KerbS4ULogonType = 12 + // Msv10s4ulogontype message type for local users with MSV1_0 + Msv10s4ulogontype = 12 + + // MicrosoftKerberosNameA is the authentication package name for Kerberos + MicrosoftKerberosNameA = "Kerberos" + // Msv10packagename is the authentication package name for MSV1_0 + Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0" + + NameSamCompatible = 2 + NameUserPrincipal = 8 + NameCanonical = 7 + + maxUPNLen = 1024 +) + +// kerbS4ULogon structure for S4U authentication (domain users) +type kerbS4ULogon struct { + MessageType uint32 + Flags uint32 + ClientUpn unicodeString + ClientRealm unicodeString +} + +// msv10s4ulogon structure for S4U authentication (local users) +type msv10s4ulogon struct { + MessageType uint32 + Flags uint32 + UserPrincipalName unicodeString + DomainName unicodeString +} + +// unicodeString structure +type unicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +// lsaString structure +type lsaString struct { + Length uint16 + MaximumLength uint16 + Buffer *byte +} + +// tokenSource structure +type tokenSource struct { + SourceName [8]byte + SourceIdentifier windows.LUID +} + +// quotaLimits structure +type quotaLimits struct { + PagedPoolLimit uint32 + NonPagedPoolLimit uint32 + MinimumWorkingSetSize uint32 + MaximumWorkingSetSize uint32 + PagefileLimit uint32 + TimeLimit int64 +} + +var ( + secur32 = windows.NewLazyDLL("secur32.dll") + procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess") + procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage") + procLsaLogonUser = secur32.NewProc("LsaLogonUser") + procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer") + procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess") + procTranslateNameW = secur32.NewProc("TranslateNameW") +) + +// newLsaString creates an LsaString from a Go string +func newLsaString(s string) lsaString { + b := append([]byte(s), 0) + return lsaString{ + Length: uint16(len(s)), + MaximumLength: uint16(len(b)), + Buffer: &b[0], + } +} + +// generateS4UUserToken creates a Windows token using S4U authentication +// This is the exact approach OpenSSH for Windows uses for public key authentication +func generateS4UUserToken(username, domain string) (windows.Handle, error) { + userCpn := buildUserCpn(username, domain) + + pd := NewPrivilegeDropper() + isDomainUser := !pd.isLocalUser(domain) + + lsaHandle, err := initializeLsaConnection() + if err != nil { + return 0, err + } + defer cleanupLsaConnection(lsaHandle) + + authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser) + if err != nil { + return 0, err + } + + logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) + if err != nil { + return 0, err + } + + return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) +} + +// buildUserCpn constructs the user principal name +func buildUserCpn(username, domain string) string { + if domain != "" && domain != "." { + return fmt.Sprintf(`%s\%s`, domain, username) + } + return username +} + +// initializeLsaConnection establishes connection to LSA +func initializeLsaConnection() (windows.Handle, error) { + + processName := newLsaString("NetBird") + var mode uint32 + var lsaHandle windows.Handle + ret, _, _ := procLsaRegisterLogonProcess.Call( + uintptr(unsafe.Pointer(&processName)), + uintptr(unsafe.Pointer(&lsaHandle)), + uintptr(unsafe.Pointer(&mode)), + ) + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret) + } + + return lsaHandle, nil +} + +// cleanupLsaConnection closes the LSA connection +func cleanupLsaConnection(lsaHandle windows.Handle) { + if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess { + log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret) + } +} + +// lookupAuthenticationPackage finds the correct authentication package +func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) { + var authPackageName lsaString + if isDomainUser { + authPackageName = newLsaString(MicrosoftKerberosNameA) + } else { + authPackageName = newLsaString(Msv10packagename) + } + + var authPackageId uint32 + ret, _, _ := procLsaLookupAuthenticationPackage.Call( + uintptr(lsaHandle), + uintptr(unsafe.Pointer(&authPackageName)), + uintptr(unsafe.Pointer(&authPackageId)), + ) + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret) + } + + return authPackageId, nil +} + +// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format) +func lookupPrincipalName(username, domain string) (string, error) { + samAccountName := fmt.Sprintf(`%s\%s`, domain, username) + samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName) + if err != nil { + return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err) + } + + upnBuf := make([]uint16, maxUPNLen+1) + upnSize := uint32(len(upnBuf)) + + ret, _, _ := procTranslateNameW.Call( + uintptr(unsafe.Pointer(samAccountNameUtf16)), + uintptr(NameSamCompatible), + uintptr(NameUserPrincipal), + uintptr(unsafe.Pointer(&upnBuf[0])), + uintptr(unsafe.Pointer(&upnSize)), + ) + + if ret != 0 { + upn := windows.UTF16ToString(upnBuf[:upnSize]) + log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn) + return upn, nil + } + + upnSize = uint32(len(upnBuf)) + ret, _, _ = procTranslateNameW.Call( + uintptr(unsafe.Pointer(samAccountNameUtf16)), + uintptr(NameSamCompatible), + uintptr(NameCanonical), + uintptr(unsafe.Pointer(&upnBuf[0])), + uintptr(unsafe.Pointer(&upnSize)), + ) + + if ret != 0 { + canonical := windows.UTF16ToString(upnBuf[:upnSize]) + slashIdx := strings.IndexByte(canonical, '/') + if slashIdx > 0 { + fqdn := canonical[:slashIdx] + upn := fmt.Sprintf("%s@%s", username, fqdn) + log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical) + return upn, nil + } + } + + log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName) + return samAccountName, nil +} + +// prepareS4ULogonStructure creates the appropriate S4U logon structure +func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { + if isDomainUser { + return prepareDomainS4ULogon(username, domain) + } + return prepareLocalS4ULogon(username) +} + +// prepareDomainS4ULogon creates S4U logon structure for domain users +func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { + upn, err := lookupPrincipalName(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("lookup principal name: %w", err) + } + + log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) + + upnUtf16, err := windows.UTF16FromString(upn) + if err != nil { + return nil, 0, fmt.Errorf(convertUsernameError, err) + } + + structSize := unsafe.Sizeof(kerbS4ULogon{}) + upnByteSize := len(upnUtf16) * 2 + logonInfoSize := structSize + uintptr(upnByteSize) + + buffer := make([]byte, logonInfoSize) + logonInfo := unsafe.Pointer(&buffer[0]) + + s4uLogon := (*kerbS4ULogon)(logonInfo) + s4uLogon.MessageType = KerbS4ULogonType + s4uLogon.Flags = 0 + + upnOffset := structSize + upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset)) + copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16) + + s4uLogon.ClientUpn = unicodeString{ + Length: uint16((len(upnUtf16) - 1) * 2), + MaximumLength: uint16(len(upnUtf16) * 2), + Buffer: upnBuffer, + } + s4uLogon.ClientRealm = unicodeString{} + + return logonInfo, logonInfoSize, nil +} + +// prepareLocalS4ULogon creates S4U logon structure for local users +func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { + log.Debugf("using Msv1_0S4ULogon for local user: %s", username) + + usernameUtf16, err := windows.UTF16FromString(username) + if err != nil { + return nil, 0, fmt.Errorf(convertUsernameError, err) + } + + domainUtf16, err := windows.UTF16FromString(".") + if err != nil { + return nil, 0, fmt.Errorf(convertDomainError, err) + } + + structSize := unsafe.Sizeof(msv10s4ulogon{}) + usernameByteSize := len(usernameUtf16) * 2 + domainByteSize := len(domainUtf16) * 2 + logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize) + + buffer := make([]byte, logonInfoSize) + logonInfo := unsafe.Pointer(&buffer[0]) + + s4uLogon := (*msv10s4ulogon)(logonInfo) + s4uLogon.MessageType = Msv10s4ulogontype + s4uLogon.Flags = 0x0 + + usernameOffset := structSize + usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset)) + copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16) + + s4uLogon.UserPrincipalName = unicodeString{ + Length: uint16((len(usernameUtf16) - 1) * 2), + MaximumLength: uint16(len(usernameUtf16) * 2), + Buffer: usernameBuffer, + } + + domainOffset := usernameOffset + uintptr(usernameByteSize) + domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset)) + copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16) + + s4uLogon.DomainName = unicodeString{ + Length: uint16((len(domainUtf16) - 1) * 2), + MaximumLength: uint16(len(domainUtf16) * 2), + Buffer: domainBuffer, + } + + return logonInfo, logonInfoSize, nil +} + +// performS4ULogon executes the S4U logon operation +func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { + var tokenSource tokenSource + copy(tokenSource.SourceName[:], "netbird") + if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { + log.Debugf("AllocateLocallyUniqueId failed") + } + + originName := newLsaString("netbird") + + var profile uintptr + var profileSize uint32 + var logonId windows.LUID + var token windows.Handle + var quotas quotaLimits + var subStatus int32 + + ret, _, _ := procLsaLogonUser.Call( + uintptr(lsaHandle), + uintptr(unsafe.Pointer(&originName)), + logon32LogonNetwork, + uintptr(authPackageId), + uintptr(logonInfo), + logonInfoSize, + 0, + uintptr(unsafe.Pointer(&tokenSource)), + uintptr(unsafe.Pointer(&profile)), + uintptr(unsafe.Pointer(&profileSize)), + uintptr(unsafe.Pointer(&logonId)), + uintptr(unsafe.Pointer(&token)), + uintptr(unsafe.Pointer("as)), + uintptr(unsafe.Pointer(&subStatus)), + ) + + if profile != 0 { + if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { + log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) + } + } + + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) + } + + log.Debugf("created S4U %s token for user %s", + map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) + return token, nil +} + +// createToken implements NetBird trust-based authentication using S4U +func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) { + fullUsername := buildUserCpn(username, domain) + + if err := userExists(fullUsername, username, domain); err != nil { + return 0, err + } + + isLocalUser := pd.isLocalUser(domain) + + if isLocalUser { + return pd.authenticateLocalUser(username, fullUsername) + } + return pd.authenticateDomainUser(username, domain, fullUsername) +} + +// userExists checks if the target useVerifier exists on the system +func userExists(fullUsername, username, domain string) error { + if _, err := lookupUser(fullUsername); err != nil { + log.Debugf("User %s not found: %v", fullUsername, err) + if domain != "" && domain != "." { + _, err = lookupUser(username) + } + if err != nil { + return fmt.Errorf("target user %s not found: %w", fullUsername, err) + } + } + return nil +} + +// isLocalUser determines if this is a local user vs domain user +func (pd *PrivilegeDropper) isLocalUser(domain string) bool { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + + return domain == "" || domain == "." || + strings.EqualFold(domain, hostname) +} + +// authenticateLocalUser handles authentication for local users +func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { + log.Debugf("using S4U authentication for local user %s", fullUsername) + token, err := generateS4UUserToken(username, ".") + if err != nil { + return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) + } + return token, nil +} + +// authenticateDomainUser handles authentication for domain users +func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { + log.Debugf("using S4U authentication for domain user %s", fullUsername) + token, err := generateS4UUserToken(username, domain) + if err != nil { + return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) + } + log.Debugf("Successfully created S4U token for domain user %s", fullUsername) + return token, nil +} + +// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables). +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) { + token, err := pd.createToken(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("user authentication: %w", err) + } + + defer func() { + if err := windows.CloseHandle(token); err != nil { + log.Debugf("close impersonation token: %v", err) + } + }() + + cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir) + if err != nil { + return nil, 0, err + } + + return cmd, primaryToken, nil +} + +// createProcessWithToken creates process with the specified token and executable path. +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) { + cmd := exec.CommandContext(ctx, executablePath, args[1:]...) + cmd.Dir = workingDir + + var primaryToken windows.Token + err := windows.DuplicateTokenEx( + sourceToken, + windows.TOKEN_ALL_ACCESS, + nil, + windows.SecurityIdentification, + windows.TokenPrimary, + &primaryToken, + ) + if err != nil { + return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err) + } + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Token: syscall.Token(primaryToken), + } + + return cmd, primaryToken, nil +} + +// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) +func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { + return nil, fmt.Errorf("su command not available on Windows") +} diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go new file mode 100644 index 000000000..d36d7cbbf --- /dev/null +++ b/client/ssh/server/jwt_test.go @@ -0,0 +1,647 @@ +package server + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "runtime" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbssh "github.com/netbirdio/netbird/client/ssh" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" + "github.com/netbirdio/netbird/client/ssh/client" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/client/ssh/testutil" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" +) + +func TestJWTEnforcement(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT enforcement tests in short mode") + } + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + t.Run("blocks_without_jwt", func(t *testing.T) { + jwtConfig := &JWTConfig{ + Issuer: "test-issuer", + Audience: "test-audience", + KeysLocation: "test-keys", + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) + if err != nil { + t.Logf("Detection failed: %v", err) + } + t.Logf("Detected server type: %s", serverType) + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + _, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + assert.Error(t, err, "SSH connection should fail when JWT is required but not provided") + }) + + t.Run("allows_when_disabled", func(t *testing.T) { + serverConfigNoJWT := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + serverNoJWT := New(serverConfigNoJWT) + require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config") + serverNoJWT.SetAllowRootLogin(true) + + serverAddrNoJWT := StartTestServer(t, serverNoJWT) + defer require.NoError(t, serverNoJWT.Stop()) + + hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT) + require.NoError(t, err) + portNoJWT, err := strconv.Atoi(portStrNoJWT) + require.NoError(t, err) + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT) + require.NoError(t, err) + assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType) + assert.False(t, serverType.RequiresJWT()) + + client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT) + require.NoError(t, err) + defer client.Close() + }) + +} + +// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL +func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) { + privateKey, jwksJSON := generateTestJWKS(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(jwksJSON); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + + return server, privateKey, server.URL +} + +// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON +func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &privateKey.PublicKey + n := publicKey.N.Bytes() + e := publicKey.E + + jwk := nbjwt.JSONWebKey{ + Kty: "RSA", + Kid: "test-key-id", + Use: "sig", + N: base64RawURLEncode(n), + E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()), + } + + jwks := nbjwt.Jwks{ + Keys: []nbjwt.JSONWebKey{jwk}, + } + + jwksJSON, err := json.Marshal(jwks) + require.NoError(t, err) + + return privateKey, jwksJSON +} + +func base64RawURLEncode(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} + +// generateValidJWT creates a valid JWT token for testing +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + return tokenString +} + +// connectWithNetBirdClient connects to SSH server using NetBird's SSH client +func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) { + t.Helper() + addr := net.JoinHostPort(host, strconv.Itoa(port)) + + ctx := context.Background() + return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{ + InsecureSkipVerify: true, + }) +} + +// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers +func TestJWTDetection(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT detection test in short mode") + } + + jwksServer, _, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) + require.NoError(t, err) + assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType) + assert.True(t, serverType.RequiresJWT()) +} + +func TestJWTFailClose(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT fail-close tests in short mode") + } + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + testCases := []struct { + name string + tokenClaims jwt.MapClaims + }{ + { + name: "blocks_token_missing_iat", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }, + }, + { + name: "blocks_token_missing_sub", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_missing_iss", + tokenClaims: jwt.MapClaims{ + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_missing_aud", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_wrong_issuer", + tokenClaims: jwt.MapClaims{ + "iss": "wrong-issuer", + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_wrong_audience", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": "wrong-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_expired_token", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(-time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + }, + }, + { + name: "blocks_token_exceeding_max_age", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + MaxTokenAge: 3600, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims) + token.Header["kid"] = "test-key-id" + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{ + cryptossh.Password(tokenString), + }, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + if conn != nil { + defer func() { + if err := conn.Close(); err != nil { + t.Logf("close connection: %v", err) + } + }() + } + + assert.Error(t, err, "Authentication should fail (fail-close)") + }) + } +} + +// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types +func TestJWTAuthentication(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT authentication tests in short mode") + } + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + testCases := []struct { + name string + token string + wantAuthOK bool + setupServer func(*Server) + testOperation func(*testing.T, *cryptossh.Client, string) error + wantOpSuccess bool + }{ + { + name: "allows_shell_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + return session.Shell() + }, + wantOpSuccess: true, + }, + { + name: "rejects_invalid_token", + token: "invalid", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("echo test") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "blocks_shell_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("echo test") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "blocks_command_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("ls") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "allows_sftp_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowSFTP(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + session.Stdout = io.Discard + session.Stderr = io.Discard + return session.RequestSubsystem("sftp") + }, + wantOpSuccess: true, + }, + { + name: "blocks_sftp_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowSFTP(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + session.Stdout = io.Discard + session.Stderr = io.Discard + err = session.RequestSubsystem("sftp") + if err == nil { + err = session.Wait() + } + return err + }, + wantOpSuccess: false, + }, + { + name: "allows_port_forward_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowRemotePortForwarding(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + ln, err := conn.Listen("tcp", "127.0.0.1:0") + if ln != nil { + defer ln.Close() + } + return err + }, + wantOpSuccess: true, + }, + { + name: "blocks_port_forward_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowLocalPortForwarding(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + ln, err := conn.Listen("tcp", "127.0.0.1:0") + if ln != nil { + defer ln.Close() + } + return err + }, + wantOpSuccess: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // TODO: Skip port forwarding tests on Windows - user switching not supported + // These features are tested on Linux/Unix platforms + if runtime.GOOS == "windows" && + (tc.name == "allows_port_forward_with_jwt" || + tc.name == "blocks_port_forward_without_jwt") { + t.Skip("Skipping port forwarding test on Windows - covered by Linux tests") + } + + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + if tc.setupServer != nil { + tc.setupServer(server) + } + + // Always set up authorization for test-user to ensure tests fail at JWT validation stage + testUserHash, err := sshuserhash.HashUserID("test-user") + require.NoError(t, err) + + // Get current OS username for machine user mapping + currentUser := testutil.GetTestUsername(t) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + currentUser: {0}, // Allow test-user (index 0) to access current OS user + }, + } + server.UpdateSSHAuth(authConfig) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + var authMethods []cryptossh.AuthMethod + if tc.token == "valid" { + token := generateValidJWT(t, privateKey, issuer, audience) + authMethods = []cryptossh.AuthMethod{ + cryptossh.Password(token), + } + } else if tc.token == "invalid" { + invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid" + authMethods = []cryptossh.AuthMethod{ + cryptossh.Password(invalidToken), + } + } + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: authMethods, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + if tc.wantAuthOK { + require.NoError(t, err, "JWT authentication should succeed") + } else if err != nil { + t.Logf("Connection failed as expected: %v", err) + return + } + if conn != nil { + defer func() { + if err := conn.Close(); err != nil { + t.Logf("close connection: %v", err) + } + }() + } + + err = tc.testOperation(t, conn, serverAddr) + if tc.wantOpSuccess { + require.NoError(t, err, "Operation should succeed") + } else { + assert.Error(t, err, "Operation should fail") + } + }) + } +} diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go new file mode 100644 index 000000000..6138f9296 --- /dev/null +++ b/client/ssh/server/port_forwarding.go @@ -0,0 +1,386 @@ +package server + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" +) + +// SessionKey uniquely identifies an SSH session +type SessionKey string + +// ConnectionKey uniquely identifies a port forwarding connection within a session +type ConnectionKey string + +// ForwardKey uniquely identifies a port forwarding listener +type ForwardKey string + +// tcpipForwardMsg represents the structure for tcpip-forward SSH requests +type tcpipForwardMsg struct { + Host string + Port uint32 +} + +// SetAllowLocalPortForwarding configures local port forwarding +func (s *Server) SetAllowLocalPortForwarding(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowLocalPortForwarding = allow +} + +// SetAllowRemotePortForwarding configures remote port forwarding +func (s *Server) SetAllowRemotePortForwarding(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowRemotePortForwarding = allow +} + +// configurePortForwarding sets up port forwarding callbacks +func (s *Server) configurePortForwarding(server *ssh.Server) { + allowLocal := s.allowLocalPortForwarding + allowRemote := s.allowRemotePortForwarding + + server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { + if !allowLocal { + log.Warnf("local port forwarding denied for %s from %s: disabled by configuration", + net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr()) + return false + } + + if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { + log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err) + return false + } + + log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort) + return true + } + + server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + if !allowRemote { + log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration", + net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr()) + return false + } + + if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { + log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err) + return false + } + + log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort) + return true + } + + log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote) +} + +// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations. +// Returns nil if allowed, error if denied. +func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error { + if ctx == nil { + return fmt.Errorf("%s port forwarding denied: no context", forwardType) + } + + username := ctx.User() + remoteAddr := "unknown" + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + + logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port}) + + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: false, + FeatureName: forwardType + " port forwarding", + }) + + if !result.Allowed { + return result.Error + } + + logger.Debugf("%s port forwarding allowed: user %s validated (port %d)", + forwardType, result.User.Username, port) + + return nil +} + +// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding. +func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + logger := s.getRequestLogger(ctx) + + if !s.isRemotePortForwardingAllowed() { + logger.Warnf("tcpip-forward request denied: remote port forwarding disabled") + return false, nil + } + + payload, err := s.parseTcpipForwardRequest(req) + if err != nil { + logger.Errorf("tcpip-forward unmarshal error: %v", err) + return false, nil + } + + if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil { + logger.Warnf("tcpip-forward denied: %v", err) + return false, nil + } + + logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port) + + sshConn, err := s.getSSHConnection(ctx) + if err != nil { + logger.Warnf("tcpip-forward request denied: %v", err) + return false, nil + } + + return s.setupDirectForward(ctx, logger, sshConn, payload) +} + +// cancelTcpipForwardHandler handles cancel-tcpip-forward requests. +func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + logger := s.getRequestLogger(ctx) + + var payload tcpipForwardMsg + if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil { + logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err) + return false, nil + } + + key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + if s.removeRemoteForwardListener(key) { + logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) + return true, nil + } + + logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port) + return false, nil +} + +// handleRemoteForwardListener handles incoming connections for remote port forwarding. +func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) { + log.Debugf("starting remote forward listener handler for %s:%d", host, port) + + defer func() { + log.Debugf("cleaning up remote forward listener for %s:%d", host, port) + if err := ln.Close(); err != nil { + log.Debugf("remote forward listener close error: %v", err) + } else { + log.Debugf("remote forward listener closed successfully for %s:%d", host, port) + } + }() + + acceptChan := make(chan acceptResult, 1) + + go func() { + for { + conn, err := ln.Accept() + select { + case acceptChan <- acceptResult{conn: conn, err: err}: + if err != nil { + return + } + case <-ctx.Done(): + return + } + } + }() + + for { + select { + case result := <-acceptChan: + if result.err != nil { + log.Debugf("remote forward accept error: %v", result.err) + return + } + go s.handleRemoteForwardConnection(ctx, result.conn, host, port) + case <-ctx.Done(): + log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port) + return + } + } +} + +// getRequestLogger creates a logger with user and remote address context +func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry { + remoteAddr := "unknown" + username := "unknown" + if ctx != nil { + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + username = ctx.User() + } + return log.WithFields(log.Fields{"user": username, "remote": remoteAddr}) +} + +// isRemotePortForwardingAllowed checks if remote port forwarding is enabled +func (s *Server) isRemotePortForwardingAllowed() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.allowRemotePortForwarding +} + +// parseTcpipForwardRequest parses the SSH request payload +func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { + var payload tcpipForwardMsg + err := cryptossh.Unmarshal(req.Payload, &payload) + return &payload, err +} + +// getSSHConnection extracts SSH connection from context +func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) { + if ctx == nil { + return nil, fmt.Errorf("no context") + } + sshConnValue := ctx.Value(ssh.ContextKeyConn) + if sshConnValue == nil { + return nil, fmt.Errorf("no SSH connection in context") + } + sshConn, ok := sshConnValue.(*cryptossh.ServerConn) + if !ok || sshConn == nil { + return nil, fmt.Errorf("invalid SSH connection in context") + } + return sshConn, nil +} + +// setupDirectForward sets up a direct port forward +func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) { + bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10)) + + ln, err := net.Listen("tcp", bindAddr) + if err != nil { + logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err) + return false, nil + } + + actualPort := payload.Port + if payload.Port == 0 { + tcpAddr := ln.Addr().(*net.TCPAddr) + actualPort = uint32(tcpAddr.Port) + logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) + } + + key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + s.storeRemoteForwardListener(key, ln) + + s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String()) + go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) + + response := make([]byte, 4) + binary.BigEndian.PutUint32(response, actualPort) + + logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort) + return true, response +} + +// acceptResult holds the result of a listener Accept() call +type acceptResult struct { + conn net.Conn + err error +} + +// handleRemoteForwardConnection handles a single remote port forwarding connection +func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { + sessionKey := s.findSessionKeyByContext(ctx) + connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port) + logger := log.WithFields(log.Fields{ + "session": sessionKey, + "conn": connID, + }) + + defer func() { + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) + } + }() + + sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if sshConn == nil { + logger.Debugf("remote forward: no SSH connection in context") + return + } + + remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) + return + } + + channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) + if err != nil { + logger.Debugf("open forward channel: %v", err) + return + } + + s.proxyForwardConnection(ctx, logger, conn, channel) +} + +// openForwardChannel creates an SSH forwarded-tcpip channel +func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) { + logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port) + + payload := struct { + ConnectedAddress string + ConnectedPort uint32 + OriginatorAddress string + OriginatorPort uint32 + }{ + ConnectedAddress: host, + ConnectedPort: port, + OriginatorAddress: remoteAddr.IP.String(), + OriginatorPort: uint32(remoteAddr.Port), + } + + channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload)) + if err != nil { + return nil, fmt.Errorf("open SSH channel: %w", err) + } + + go cryptossh.DiscardRequests(reqs) + return channel, nil +} + +// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel +func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(channel, conn); err != nil { + logger.Debugf("copy error (conn->channel): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(conn, channel); err != nil { + logger.Debugf("copy error (channel->conn): %v", err) + } + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + // First copy finished, wait for second copy or context cancellation + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + } + } + + if err := channel.Close(); err != nil { + logger.Debugf("channel close error: %v", err) + } + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) + } +} diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go new file mode 100644 index 000000000..82718d002 --- /dev/null +++ b/client/ssh/server/server.go @@ -0,0 +1,751 @@ +package server + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + gojwt "github.com/golang-jwt/jwt/v5" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" + "golang.org/x/exp/maps" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/wgaddr" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/auth/jwt" + "github.com/netbirdio/netbird/version" +) + +// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server +const DefaultSSHPort = 22 + +// InternalSSHPort is the port SSH server listens on and is redirected to +const InternalSSHPort = 22022 + +const ( + errWriteSession = "write session error: %v" + errExitSession = "exit session error: %v" + + msgPrivilegedUserDisabled = "privileged user login is disabled" + + // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server + DefaultJWTMaxTokenAge = 5 * 60 +) + +var ( + ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled) + ErrUserNotFound = errors.New("user not found") +) + +// PrivilegedUserError represents an error when privileged user login is disabled +type PrivilegedUserError struct { + Username string +} + +func (e *PrivilegedUserError) Error() string { + return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username) +} + +func (e *PrivilegedUserError) Is(target error) bool { + return target == ErrPrivilegedUserDisabled +} + +// UserNotFoundError represents an error when a user cannot be found +type UserNotFoundError struct { + Username string + Cause error +} + +func (e *UserNotFoundError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause) + } + return fmt.Sprintf("user %s not found", e.Username) +} + +func (e *UserNotFoundError) Is(target error) bool { + return target == ErrUserNotFound +} + +func (e *UserNotFoundError) Unwrap() error { + return e.Cause +} + +// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors +func logSessionExitError(logger *log.Entry, err error) { + if err != nil && !errors.Is(err, io.EOF) { + logger.Warnf(errExitSession, err) + } +} + +// safeLogCommand returns a safe representation of the command for logging +func safeLogCommand(cmd []string) string { + if len(cmd) == 0 { + return "" + } + if len(cmd) == 1 { + return cmd[0] + } + return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) +} + +type sshConnectionState struct { + hasActivePortForward bool + username string + remoteAddr string +} + +type authKey string + +func newAuthKey(username string, remoteAddr net.Addr) authKey { + return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String())) +} + +type Server struct { + sshServer *ssh.Server + mu sync.RWMutex + hostKeyPEM []byte + sessions map[SessionKey]ssh.Session + sessionCancels map[ConnectionKey]context.CancelFunc + sessionJWTUsers map[SessionKey]string + pendingAuthJWT map[authKey]string + + allowLocalPortForwarding bool + allowRemotePortForwarding bool + allowRootLogin bool + allowSFTP bool + jwtEnabled bool + + netstackNet *netstack.Net + + wgAddress wgaddr.Address + + remoteForwardListeners map[ForwardKey]net.Listener + sshConnections map[*cryptossh.ServerConn]*sshConnectionState + + jwtValidator *jwt.Validator + jwtExtractor *jwt.ClaimsExtractor + jwtConfig *JWTConfig + + authorizer *sshauth.Authorizer + + suSupportsPty bool + loginIsUtilLinux bool +} + +type JWTConfig struct { + Issuer string + Audience string + KeysLocation string + MaxTokenAge int64 +} + +// Config contains all SSH server configuration options +type Config struct { + // JWT authentication configuration. If nil, JWT authentication is disabled + JWT *JWTConfig + + // HostKey is the SSH server host key in PEM format + HostKeyPEM []byte +} + +// SessionInfo contains information about an active SSH session +type SessionInfo struct { + Username string + RemoteAddress string + Command string + JWTUsername string +} + +// New creates an SSH server instance with the provided host key and optional JWT configuration +// If jwtConfig is nil, JWT authentication is disabled +func New(config *Config) *Server { + s := &Server{ + mu: sync.RWMutex{}, + hostKeyPEM: config.HostKeyPEM, + sessions: make(map[SessionKey]ssh.Session), + sessionJWTUsers: make(map[SessionKey]string), + pendingAuthJWT: make(map[authKey]string), + remoteForwardListeners: make(map[ForwardKey]net.Listener), + sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), + jwtEnabled: config.JWT != nil, + jwtConfig: config.JWT, + authorizer: sshauth.NewAuthorizer(), // Initialize with empty config + } + + return s +} + +// Start runs the SSH server +func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sshServer != nil { + return errors.New("SSH server is already running") + } + + s.suSupportsPty = s.detectSuPtySupport(ctx) + s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx) + + ln, addrDesc, err := s.createListener(ctx, addr) + if err != nil { + return fmt.Errorf("create listener: %w", err) + } + + sshServer, err := s.createSSHServer(ln.Addr()) + if err != nil { + s.closeListener(ln) + return fmt.Errorf("create SSH server: %w", err) + } + + s.sshServer = sshServer + log.Infof("SSH server started on %s", addrDesc) + + go func() { + if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Errorf("SSH server error: %v", err) + } + }() + return nil +} + +func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { + if s.netstackNet != nil { + ln, err := s.netstackNet.ListenTCPAddrPort(addr) + if err != nil { + return nil, "", fmt.Errorf("listen on netstack: %w", err) + } + return ln, fmt.Sprintf("netstack %s", addr), nil + } + + tcpAddr := net.TCPAddrFromAddrPort(addr) + lc := net.ListenConfig{} + ln, err := lc.Listen(ctx, "tcp", tcpAddr.String()) + if err != nil { + return nil, "", fmt.Errorf("listen: %w", err) + } + return ln, addr.String(), nil +} + +func (s *Server) closeListener(ln net.Listener) { + if ln == nil { + return + } + if err := ln.Close(); err != nil { + log.Debugf("listener close error: %v", err) + } +} + +// Stop closes the SSH server +func (s *Server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sshServer == nil { + return nil + } + + if err := s.sshServer.Close(); err != nil { + log.Debugf("close SSH server: %v", err) + } + + s.sshServer = nil + + maps.Clear(s.sessions) + maps.Clear(s.sessionJWTUsers) + maps.Clear(s.pendingAuthJWT) + maps.Clear(s.sshConnections) + + for _, cancelFunc := range s.sessionCancels { + cancelFunc() + } + maps.Clear(s.sessionCancels) + + for _, listener := range s.remoteForwardListeners { + if err := listener.Close(); err != nil { + log.Debugf("close remote forward listener: %v", err) + } + } + maps.Clear(s.remoteForwardListeners) + + return nil +} + +// GetStatus returns the current status of the SSH server and active sessions +func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { + s.mu.RLock() + defer s.mu.RUnlock() + + enabled = s.sshServer != nil + + for sessionKey, session := range s.sessions { + cmd := "" + if len(session.Command()) > 0 { + cmd = safeLogCommand(session.Command()) + } + + jwtUsername := s.sessionJWTUsers[sessionKey] + + sessions = append(sessions, SessionInfo{ + Username: session.User(), + RemoteAddress: session.RemoteAddr().String(), + Command: cmd, + JWTUsername: jwtUsername, + }) + } + + return enabled, sessions +} + +// SetNetstackNet sets the netstack network for userspace networking +func (s *Server) SetNetstackNet(net *netstack.Net) { + s.mu.Lock() + defer s.mu.Unlock() + s.netstackNet = net +} + +// SetNetworkValidation configures network-based connection filtering +func (s *Server) SetNetworkValidation(addr wgaddr.Address) { + s.mu.Lock() + defer s.mu.Unlock() + s.wgAddress = addr +} + +// UpdateSSHAuth updates the SSH fine-grained access control configuration +// This should be called when network map updates include new SSH auth configuration +func (s *Server) UpdateSSHAuth(config *sshauth.Config) { + s.mu.Lock() + defer s.mu.Unlock() + + // Reset JWT validator/extractor to pick up new userIDClaim + s.jwtValidator = nil + s.jwtExtractor = nil + + s.authorizer.Update(config) +} + +// ensureJWTValidator initializes the JWT validator and extractor if not already initialized +func (s *Server) ensureJWTValidator() error { + s.mu.RLock() + if s.jwtValidator != nil && s.jwtExtractor != nil { + s.mu.RUnlock() + return nil + } + config := s.jwtConfig + authorizer := s.authorizer + s.mu.RUnlock() + + if config == nil { + return fmt.Errorf("JWT config not set") + } + + log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience) + + validator := jwt.NewValidator( + config.Issuer, + []string{config.Audience}, + config.KeysLocation, + true, + ) + + // Use custom userIDClaim from authorizer if available + extractorOptions := []jwt.ClaimsExtractorOption{ + jwt.WithAudience(config.Audience), + } + if authorizer.GetUserIDClaim() != "" { + extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim())) + log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim()) + } + + extractor := jwt.NewClaimsExtractor(extractorOptions...) + + s.mu.Lock() + defer s.mu.Unlock() + + if s.jwtValidator != nil && s.jwtExtractor != nil { + return nil + } + + s.jwtValidator = validator + s.jwtExtractor = extractor + + log.Infof("JWT validator initialized successfully") + return nil +} + +func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) { + s.mu.RLock() + jwtValidator := s.jwtValidator + jwtConfig := s.jwtConfig + s.mu.RUnlock() + + if jwtValidator == nil { + return nil, fmt.Errorf("JWT validator not initialized") + } + + token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString) + if err != nil { + if jwtConfig != nil { + if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil { + return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w", + jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err) + } + } + return nil, fmt.Errorf("validate token: %w", err) + } + + if err := s.checkTokenAge(token, jwtConfig); err != nil { + return nil, err + } + + return token, nil +} + +func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error { + if jwtConfig == nil { + return nil + } + + maxTokenAge := jwtConfig.MaxTokenAge + if maxTokenAge <= 0 { + maxTokenAge = DefaultJWTMaxTokenAge + } + + claims, ok := token.Claims.(gojwt.MapClaims) + if !ok { + userID := extractUserID(token) + return fmt.Errorf("token has invalid claims format (user=%s)", userID) + } + + iat, ok := claims["iat"].(float64) + if !ok { + userID := extractUserID(token) + return fmt.Errorf("token missing iat claim (user=%s)", userID) + } + + issuedAt := time.Unix(int64(iat), 0) + tokenAge := time.Since(issuedAt) + maxAge := time.Duration(maxTokenAge) * time.Second + if tokenAge > maxAge { + userID := getUserIDFromClaims(claims) + return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge) + } + + return nil +} + +func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) { + s.mu.RLock() + jwtExtractor := s.jwtExtractor + s.mu.RUnlock() + + if jwtExtractor == nil { + userID := extractUserID(token) + return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID) + } + + userAuth, err := jwtExtractor.ToUserAuth(token) + if err != nil { + userID := extractUserID(token) + return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err) + } + + if !s.hasSSHAccess(&userAuth) { + return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId) + } + + return &userAuth, nil +} + +func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool { + return userAuth.UserId != "" +} + +func extractUserID(token *gojwt.Token) string { + if token == nil { + return "unknown" + } + claims, ok := token.Claims.(gojwt.MapClaims) + if !ok { + return "unknown" + } + return getUserIDFromClaims(claims) +} + +func getUserIDFromClaims(claims gojwt.MapClaims) string { + if sub, ok := claims["sub"].(string); ok && sub != "" { + return sub + } + if userID, ok := claims["user_id"].(string); ok && userID != "" { + return userID + } + if email, ok := claims["email"].(string); ok && email != "" { + return email + } + return "unknown" +} + +func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) { + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid token format") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("decode payload: %w", err) + } + + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("parse claims: %w", err) + } + + return claims, nil +} + +func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { + osUsername := ctx.User() + remoteAddr := ctx.RemoteAddr() + + if err := s.ensureJWTValidator(); err != nil { + log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err) + return false + } + + token, err := s.validateJWTToken(password) + if err != nil { + log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err) + return false + } + + userAuth, err := s.extractAndValidateUser(token) + if err != nil { + log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err) + return false + } + + s.mu.RLock() + authorizer := s.authorizer + s.mu.RUnlock() + + if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil { + log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err) + return false + } + + key := newAuthKey(osUsername, remoteAddr) + s.mu.Lock() + s.pendingAuthJWT[key] = userAuth.UserId + s.mu.Unlock() + + log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr) + return true +} + +func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + if state, exists := s.sshConnections[sshConn]; exists { + state.hasActivePortForward = true + } else { + s.sshConnections[sshConn] = &sshConnectionState{ + hasActivePortForward: true, + username: username, + remoteAddr: remoteAddr, + } + } +} + +func (s *Server) connectionCloseHandler(conn net.Conn, err error) { + // We can't extract the SSH connection from net.Conn directly + // Connection cleanup will happen during session cleanup or via timeout + log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err) +} + +func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { + if ctx == nil { + return "unknown" + } + + // Try to match by SSH connection + sshConn := ctx.Value(ssh.ContextKeyConn) + if sshConn == nil { + return "unknown" + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // Look through sessions to find one with matching connection + for sessionKey, session := range s.sessions { + if session.Context().Value(ssh.ContextKeyConn) == sshConn { + return sessionKey + } + } + + // If no session found, this might be during early connection setup + // Return a temporary key that we'll fix up later + if ctx.User() != "" && ctx.RemoteAddr() != nil { + tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) + log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey) + return tempKey + } + + return "unknown" +} + +func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { + s.mu.RLock() + netbirdNetwork := s.wgAddress.Network + localIP := s.wgAddress.IP + s.mu.RUnlock() + + if !netbirdNetwork.IsValid() || !localIP.IsValid() { + return conn + } + + remoteAddr := conn.RemoteAddr() + tcpAddr, ok := remoteAddr.(*net.TCPAddr) + if !ok { + log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr) + return nil + } + + remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP) + if !ok { + log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP) + return nil + } + + // Block connections from our own IP (prevent local apps from connecting to ourselves) + if remoteIP == localIP { + log.Warnf("SSH connection rejected from own IP %s", remoteIP) + return nil + } + + if !netbirdNetwork.Contains(remoteIP) { + log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP) + return nil + } + + log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) + return conn +} + +func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { + if err := enableUserSwitching(); err != nil { + log.Warnf("failed to enable user switching: %v", err) + } + + serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion()) + if s.jwtEnabled { + serverVersion += " " + detection.JWTRequiredMarker + } + + server := &ssh.Server{ + Addr: addr.String(), + Handler: s.sessionHandler, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": s.sftpSubsystemHandler, + }, + HostSigners: []ssh.Signer{}, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": s.directTCPIPHandler, + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": s.tcpipForwardHandler, + "cancel-tcpip-forward": s.cancelTcpipForwardHandler, + }, + ConnCallback: s.connectionValidator, + ConnectionFailedCallback: s.connectionCloseHandler, + Version: serverVersion, + } + + if s.jwtEnabled { + server.PasswordHandler = s.passwordHandler + } + + hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM) + if err := server.SetOption(hostKeyPEM); err != nil { + return nil, fmt.Errorf("set host key: %w", err) + } + + s.configurePortForwarding(server) + return server, nil +} + +func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { + s.mu.Lock() + defer s.mu.Unlock() + s.remoteForwardListeners[key] = ln +} + +func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { + s.mu.Lock() + defer s.mu.Unlock() + + ln, exists := s.remoteForwardListeners[key] + if !exists { + return false + } + + delete(s.remoteForwardListeners, key) + if err := ln.Close(); err != nil { + log.Debugf("remote forward listener close error: %v", err) + } + + return true +} + +func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { + var payload struct { + Host string + Port uint32 + OriginatorAddr string + OriginatorPort uint32 + } + + if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { + if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil { + log.Debugf("channel reject error: %v", err) + } + return + } + + s.mu.RLock() + allowLocal := s.allowLocalPortForwarding + s.mu.RUnlock() + + if !allowLocal { + log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) + _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") + return + } + + // Check privilege requirements for the destination port + if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { + log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") + return + } + + log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + + ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) +} diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go new file mode 100644 index 000000000..24e455025 --- /dev/null +++ b/client/ssh/server/server_config_test.go @@ -0,0 +1,394 @@ +package server + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/ssh" + sshclient "github.com/netbirdio/netbird/client/ssh/client" +) + +func TestServer_RootLoginRestriction(t *testing.T) { + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowRoot bool + username string + expectError bool + description string + }{ + { + name: "root login allowed", + allowRoot: true, + username: "root", + expectError: false, + description: "Root login should succeed when allowed", + }, + { + name: "root login denied", + allowRoot: false, + username: "root", + expectError: true, + description: "Root login should fail when disabled", + }, + { + name: "regular user login always allowed", + allowRoot: false, + username: "testuser", + expectError: false, + description: "Regular user login should work regardless of root setting", + }, + } + + // Add Windows Administrator tests if on Windows + if runtime.GOOS == "windows" { + tests = append(tests, []struct { + name string + allowRoot bool + username string + expectError bool + description string + }{ + { + name: "Administrator login allowed", + allowRoot: true, + username: "Administrator", + expectError: false, + description: "Administrator login should succeed when allowed", + }, + { + name: "Administrator login denied", + allowRoot: false, + username: "Administrator", + expectError: true, + description: "Administrator login should fail when disabled", + }, + { + name: "administrator login denied (lowercase)", + allowRoot: false, + username: "administrator", + expectError: true, + description: "administrator login should fail when disabled (case insensitive)", + }, + }...) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock privileged environment to test root access controls + // Set up mock users based on platform + mockUsers := map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"), + } + + // Add Windows-specific users for Administrator tests + if runtime.GOOS == "windows" { + mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator") + mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator") + } + + cleanup := setupTestDependencies( + createTestUser("root", "0", "0", "/root"), // Running as root + nil, + runtime.GOOS, + 0, // euid 0 (root) + mockUsers, + nil, + ) + defer cleanup() + + // Create server with specific configuration + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(tt.allowRoot) + + // Test the userNameLookup method directly + user, err := server.userNameLookup(tt.username) + + if tt.expectError { + assert.Error(t, err, tt.description) + if tt.username == "root" || strings.ToLower(tt.username) == "administrator" { + // Check for appropriate error message based on platform capabilities + errorMsg := err.Error() + // Either privileged user restriction OR user switching limitation + hasPrivilegedError := strings.Contains(errorMsg, "privileged user") + hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported") + assert.True(t, hasPrivilegedError || hasSwitchingError, + "Expected privileged user or user switching error, got: %s", errorMsg) + } + } else { + if tt.username == "root" || strings.ToLower(tt.username) == "administrator" { + // For privileged users, we expect either success or a different error + // (like user not found), but not the "login disabled" error + if err != nil { + assert.NotContains(t, err.Error(), "privileged user login is disabled") + } + } else { + // For regular users, lookup should generally succeed or fall back gracefully + // Note: may return current user as fallback + assert.NotNil(t, user) + } + } + }) + } +} + +func TestServer_PortForwardingRestriction(t *testing.T) { + // Test that the port forwarding callbacks properly respect configuration flags + // This is a unit test of the callback logic, not a full integration test + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowLocalForwarding bool + allowRemoteForwarding bool + description string + }{ + { + name: "all forwarding allowed", + allowLocalForwarding: true, + allowRemoteForwarding: true, + description: "Both local and remote forwarding should be allowed", + }, + { + name: "local forwarding disabled", + allowLocalForwarding: false, + allowRemoteForwarding: true, + description: "Local forwarding should be denied when disabled", + }, + { + name: "remote forwarding disabled", + allowLocalForwarding: true, + allowRemoteForwarding: false, + description: "Remote forwarding should be denied when disabled", + }, + { + name: "all forwarding disabled", + allowLocalForwarding: false, + allowRemoteForwarding: false, + description: "Both forwarding types should be denied when disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create server with specific configuration + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowLocalPortForwarding(tt.allowLocalForwarding) + server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding) + + // We need to access the internal configuration to simulate the callback tests + // Since the callbacks are created inside the Start method, we'll test the logic directly + + // Test the configuration values are set correctly + server.mu.RLock() + allowLocal := server.allowLocalPortForwarding + allowRemote := server.allowRemotePortForwarding + server.mu.RUnlock() + + assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly") + assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly") + + // Simulate the callback logic + localResult := allowLocal // This would be the callback return value + remoteResult := allowRemote // This would be the callback return value + + assert.Equal(t, tt.allowLocalForwarding, localResult, + "Local port forwarding callback should return correct value") + assert.Equal(t, tt.allowRemoteForwarding, remoteResult, + "Remote port forwarding callback should return correct value") + }) + } +} + +func TestServer_PortConflictHandling(t *testing.T) { + // Test that multiple sessions requesting the same local port are handled naturally by the OS + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Get a free port for testing + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + testPort := ln.Addr().(*net.TCPAddr).Port + err = ln.Close() + require.NoError(t, err) + + // Connect first client + ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel1() + + client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client1.Close() + assert.NoError(t, err) + }() + + // Connect second client + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client2.Close() + assert.NoError(t, err) + }() + + // First client binds to the test port + localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort) + remoteAddr := "127.0.0.1:80" + + // Start first client's port forwarding + done1 := make(chan error, 1) + go func() { + // This should succeed and hold the port + err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr) + done1 <- err + }() + + // Give first client time to bind + time.Sleep(200 * time.Millisecond) + + // Second client tries to bind to same port + localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort) + + shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer shortCancel() + + err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr) + // Second client should fail due to "address already in use" + assert.Error(t, err, "Second client should fail to bind to same port") + if err != nil { + // The error should indicate the address is already in use + errMsg := strings.ToLower(err.Error()) + if runtime.GOOS == "windows" { + assert.Contains(t, errMsg, "only one usage of each socket address", + "Error should indicate port conflict") + } else { + assert.Contains(t, errMsg, "address already in use", + "Error should indicate port conflict") + } + } + + // Cancel first client's context and wait for it to finish + cancel1() + select { + case err1 := <-done1: + // Should get context cancelled or deadline exceeded + assert.Error(t, err1, "First client should exit when context cancelled") + case <-time.After(2 * time.Second): + t.Error("First client did not exit within timeout") + } +} + +func TestServer_IsPrivilegedUser(t *testing.T) { + + tests := []struct { + username string + expected bool + description string + }{ + { + username: "root", + expected: true, + description: "root should be considered privileged", + }, + { + username: "regular", + expected: false, + description: "regular user should not be privileged", + }, + { + username: "", + expected: false, + description: "empty username should not be privileged", + }, + } + + // Add Windows-specific tests + if runtime.GOOS == "windows" { + tests = append(tests, []struct { + username string + expected bool + description string + }{ + { + username: "Administrator", + expected: true, + description: "Administrator should be considered privileged on Windows", + }, + { + username: "administrator", + expected: true, + description: "administrator should be considered privileged on Windows (case insensitive)", + }, + }...) + } else { + // On non-Windows systems, Administrator should not be privileged + tests = append(tests, []struct { + username string + expected bool + description string + }{ + { + username: "Administrator", + expected: false, + description: "Administrator should not be privileged on non-Windows systems", + }, + }...) + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + result := isPrivilegedUsername(tt.username) + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go new file mode 100644 index 000000000..661068539 --- /dev/null +++ b/client/ssh/server/server_test.go @@ -0,0 +1,441 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "os/user" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" +) + +func TestServer_StartStop(t *testing.T) { + key, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: key, + JWT: nil, + } + server := New(serverConfig) + + err = server.Stop() + assert.NoError(t, err) +} + +func TestSSHServerIntegration(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server with random port + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server in background + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + // Get a free port + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key for verification + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + // Create SSH client config + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // Connect to SSH server + client, err := cryptossh.Dial("tcp", serverAddr, config) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("close client: %v", err) + } + }() + + // Test creating a session + session, err := client.NewSession() + require.NoError(t, err) + defer func() { + if err := session.Close(); err != nil { + t.Logf("close session: %v", err) + } + }() + + // Note: Since we don't have a real shell environment in tests, + // we can't test actual command execution, but we can verify + // the connection and authentication work + t.Log("SSH connection and authentication successful") +} + +func TestSSHServerMultipleConnections(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // Test multiple concurrent connections + const numConnections = 5 + results := make(chan error, numConnections) + + for i := 0; i < numConnections; i++ { + go func(id int) { + client, err := cryptossh.Dial("tcp", serverAddr, config) + if err != nil { + results <- fmt.Errorf("connection %d failed: %w", id, err) + return + } + defer func() { + _ = client.Close() // Ignore error in test goroutine + }() + + session, err := client.NewSession() + if err != nil { + results <- fmt.Errorf("session %d failed: %w", id, err) + return + } + defer func() { + _ = session.Close() // Ignore error in test goroutine + }() + + results <- nil + }(i) + } + + // Wait for all connections to complete + for i := 0; i < numConnections; i++ { + select { + case err := <-results: + assert.NoError(t, err) + case <-time.After(10 * time.Second): + t.Fatalf("Connection %d timed out", i) + } + } +} + +func TestSSHServerNoAuthMode(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Generate a client private key for SSH protocol (server doesn't check it) + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + // Try to connect with client key + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(clientSigner), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // This should succeed in no-auth mode (server doesn't verify keys) + conn, err := cryptossh.Dial("tcp", serverAddr, config) + assert.NoError(t, err, "Connection should succeed in no-auth mode") + if conn != nil { + assert.NoError(t, conn.Close()) + } +} + +func TestSSHServerStartStopCycle(t *testing.T) { + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + serverAddr := "127.0.0.1:0" + + // Test multiple start/stop cycles + for i := 0; i < 3; i++ { + t.Logf("Start/stop cycle %d", i+1) + + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case <-started: + case err := <-errChan: + t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err) + case <-time.After(5 * time.Second): + t.Fatalf("Cycle %d: Server start timeout", i+1) + } + + err = server.Stop() + require.NoError(t, err, "Cycle %d: Stop should succeed", i+1) + } +} + +func TestSSHServer_WindowsShellHandling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping Windows shell test in short mode") + } + + server := &Server{} + + if runtime.GOOS == "windows" { + // Test Windows cmd.exe shell behavior + args := server.getShellCommandArgs("cmd.exe", "echo test") + assert.Equal(t, "cmd.exe", args[0]) + assert.Equal(t, "-Command", args[1]) + assert.Equal(t, "echo test", args[2]) + + // Test PowerShell behavior + args = server.getShellCommandArgs("powershell.exe", "echo test") + assert.Equal(t, "powershell.exe", args[0]) + assert.Equal(t, "-Command", args[1]) + assert.Equal(t, "echo test", args[2]) + } else { + // Test Unix shell behavior + args := server.getShellCommandArgs("/bin/sh", "echo test") + assert.Equal(t, "/bin/sh", args[0]) + assert.Equal(t, "-l", args[1]) + assert.Equal(t, "-c", args[2]) + assert.Equal(t, "echo test", args[3]) + } +} + +func TestSSHServer_PortForwardingConfiguration(t *testing.T) { + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig1 := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server1 := New(serverConfig1) + + serverConfig2 := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server2 := New(serverConfig2) + + assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security") + assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security") + + server2.SetAllowLocalPortForwarding(true) + server2.SetAllowRemotePortForwarding(true) + + assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set") + assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set") +} diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go new file mode 100644 index 000000000..4e6d72098 --- /dev/null +++ b/client/ssh/server/session_handlers.go @@ -0,0 +1,168 @@ +package server + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" +) + +// sessionHandler handles SSH sessions +func (s *Server) sessionHandler(session ssh.Session) { + sessionKey := s.registerSession(session) + + key := newAuthKey(session.User(), session.RemoteAddr()) + s.mu.Lock() + jwtUsername := s.pendingAuthJWT[key] + if jwtUsername != "" { + s.sessionJWTUsers[sessionKey] = jwtUsername + delete(s.pendingAuthJWT, key) + } + s.mu.Unlock() + + logger := log.WithField("session", sessionKey) + if jwtUsername != "" { + logger = logger.WithField("jwt_user", jwtUsername) + logger.Infof("SSH session started (JWT user: %s)", jwtUsername) + } else { + logger.Infof("SSH session started") + } + sessionStart := time.Now() + + defer s.unregisterSession(sessionKey, session) + defer func() { + duration := time.Since(sessionStart).Round(time.Millisecond) + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + logger.Warnf("close session after %v: %v", duration, err) + } + logger.Infof("SSH session closed after %v", duration) + }() + + privilegeResult, err := s.userPrivilegeCheck(session.User()) + if err != nil { + s.handlePrivError(logger, session, err) + return + } + + ptyReq, winCh, isPty := session.Pty() + hasCommand := len(session.Command()) > 0 + + switch { + case isPty && hasCommand: + // ssh -t - Pty command execution + s.handleCommand(logger, session, privilegeResult, winCh) + case isPty: + // ssh - Pty interactive session (login) + s.handlePty(logger, session, privilegeResult, ptyReq, winCh) + case hasCommand: + // ssh - non-Pty command execution + s.handleCommand(logger, session, privilegeResult, nil) + default: + s.rejectInvalidSession(logger, session) + } +} + +func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { + if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) +} + +func (s *Server) registerSession(session ssh.Session) SessionKey { + sessionID := session.Context().Value(ssh.ContextKeySessionID) + if sessionID == nil { + sessionID = fmt.Sprintf("%p", session) + } + + // Create a short 4-byte identifier from the full session ID + hasher := sha256.New() + hasher.Write([]byte(fmt.Sprintf("%v", sessionID))) + hash := hasher.Sum(nil) + shortID := hex.EncodeToString(hash[:4]) + + remoteAddr := session.RemoteAddr().String() + username := session.User() + sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) + + s.mu.Lock() + s.sessions[sessionKey] = session + s.mu.Unlock() + + return sessionKey +} + +func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { + s.mu.Lock() + delete(s.sessions, sessionKey) + delete(s.sessionJWTUsers, sessionKey) + + // Cancel all port forwarding connections for this session + var connectionsToCancel []ConnectionKey + for key := range s.sessionCancels { + if strings.HasPrefix(string(key), string(sessionKey)+"-") { + connectionsToCancel = append(connectionsToCancel, key) + } + } + + for _, key := range connectionsToCancel { + if cancelFunc, exists := s.sessionCancels[key]; exists { + log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key) + cancelFunc() + delete(s.sessionCancels, key) + } + } + + if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil { + if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok { + delete(s.sshConnections, sshConn) + } + } + + s.mu.Unlock() +} + +func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { + logger.Warnf("user privilege check failed: %v", err) + + errorMsg := s.buildUserLookupErrorMessage(err) + + if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if exitErr := session.Exit(1); exitErr != nil { + logSessionExitError(logger, exitErr) + } +} + +// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type +func (s *Server) buildUserLookupErrorMessage(err error) string { + var privilegedErr *PrivilegedUserError + + switch { + case errors.As(err, &privilegedErr): + if privilegedErr.Username == "root" { + return "root login is disabled on this SSH server\n" + } + return "privileged user access is disabled on this SSH server\n" + + case errors.Is(err, ErrPrivilegeRequired): + return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n" + + case errors.Is(err, ErrPrivilegedUserSwitch): + return "Cannot switch to privileged user - current user lacks required privileges\n" + + default: + return "User authentication failed\n" + } +} diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go new file mode 100644 index 000000000..c35e4da0b --- /dev/null +++ b/client/ssh/server/session_handlers_js.go @@ -0,0 +1,22 @@ +//go:build js + +package server + +import ( + "fmt" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// handlePty is not supported on JS/WASM +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { + errorMsg := "PTY sessions are not supported on WASM/JS platform\n" + if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false +} diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go new file mode 100644 index 000000000..c2b9f552b --- /dev/null +++ b/client/ssh/server/sftp.go @@ -0,0 +1,81 @@ +package server + +import ( + "fmt" + "io" + + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" +) + +// SetAllowSFTP enables or disables SFTP support +func (s *Server) SetAllowSFTP(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowSFTP = allow +} + +// sftpSubsystemHandler handles SFTP subsystem requests +func (s *Server) sftpSubsystemHandler(sess ssh.Session) { + s.mu.RLock() + allowSFTP := s.allowSFTP + s.mu.RUnlock() + + if !allowSFTP { + log.Debugf("SFTP subsystem request denied: SFTP disabled") + if err := sess.Exit(1); err != nil { + log.Debugf("SFTP session exit failed: %v", err) + } + return + } + + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: sess.User(), + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSFTP, + }) + + if !result.Allowed { + log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error) + if err := sess.Exit(1); err != nil { + log.Debugf("exit SFTP session: %v", err) + } + return + } + + log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username) + + if !result.RequiresUserSwitching { + if err := s.executeSftpDirect(sess); err != nil { + log.Errorf("SFTP direct execution: %v", err) + } + return + } + + if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil { + log.Errorf("SFTP privilege drop execution: %v", err) + } +} + +// executeSftpDirect executes SFTP directly without privilege dropping +func (s *Server) executeSftpDirect(sess ssh.Session) error { + log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User()) + + sftpServer, err := sftp.NewServer(sess) + if err != nil { + return fmt.Errorf("SFTP server creation: %w", err) + } + + defer func() { + if err := sftpServer.Close(); err != nil { + log.Debugf("failed to close sftp server: %v", err) + } + }() + + if err := sftpServer.Serve(); err != nil && err != io.EOF { + return fmt.Errorf("serve: %w", err) + } + + return nil +} diff --git a/client/ssh/server/sftp_js.go b/client/ssh/server/sftp_js.go new file mode 100644 index 000000000..3b27aeff4 --- /dev/null +++ b/client/ssh/server/sftp_js.go @@ -0,0 +1,12 @@ +//go:build js + +package server + +import ( + "os/user" +) + +// parseUserCredentials is not supported on JS/WASM +func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) { + return 0, 0, nil, errNotSupported +} diff --git a/client/ssh/server/sftp_test.go b/client/ssh/server/sftp_test.go new file mode 100644 index 000000000..32a3643e4 --- /dev/null +++ b/client/ssh/server/sftp_test.go @@ -0,0 +1,228 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "os" + "os/user" + "testing" + "time" + + "github.com/pkg/sftp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/ssh" +) + +func TestSSHServer_SFTPSubsystem(t *testing.T) { + // Skip SFTP test when running as root due to protocol issues in some environments + if os.Geteuid() == 0 { + t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues") + } + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server with SFTP enabled + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowSFTP(true) + server.SetAllowRootLogin(true) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // (currentUser already obtained at function start) + + // Create SSH client connection + clientConfig := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 5 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig) + require.NoError(t, err, "SSH connection should succeed") + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + // Create SFTP client + sftpClient, err := sftp.NewClient(conn) + require.NoError(t, err, "SFTP client creation should succeed") + defer func() { + if err := sftpClient.Close(); err != nil { + t.Logf("SFTP client close error: %v", err) + } + }() + + // Test basic SFTP operations + workingDir, err := sftpClient.Getwd() + assert.NoError(t, err, "Should be able to get working directory") + assert.NotEmpty(t, workingDir, "Working directory should not be empty") + + // Test directory listing + files, err := sftpClient.ReadDir(".") + assert.NoError(t, err, "Should be able to list current directory") + assert.NotNil(t, files, "File list should not be nil") +} + +func TestSSHServer_SFTPDisabled(t *testing.T) { + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server with SFTP disabled + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowSFTP(false) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // (currentUser already obtained at function start) + + // Create SSH client connection + clientConfig := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 5 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig) + require.NoError(t, err, "SSH connection should succeed") + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + // Try to create SFTP client - should fail when SFTP is disabled + _, err = sftp.NewClient(conn) + assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled") +} diff --git a/client/ssh/server/sftp_unix.go b/client/ssh/server/sftp_unix.go new file mode 100644 index 000000000..44202bead --- /dev/null +++ b/client/ssh/server/sftp_unix.go @@ -0,0 +1,71 @@ +//go:build !windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping +func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { + uid, gid, groups, err := s.parseUserCredentials(targetUser) + if err != nil { + return fmt.Errorf("parse user credentials: %w", err) + } + + sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir) + if err != nil { + return fmt.Errorf("create executor: %w", err) + } + + sftpCmd.Stdin = sess + sftpCmd.Stdout = sess + sftpCmd.Stderr = sess.Stderr() + + log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid) + + if err := sftpCmd.Start(); err != nil { + return fmt.Errorf("starting SFTP executor: %w", err) + } + + if err := sftpCmd.Wait(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + log.Tracef("SFTP process exited with code %d", exitError.ExitCode()) + return nil + } + return fmt.Errorf("exec: %w", err) + } + + return nil +} + +// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping +func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) { + netbirdPath, err := os.Executable() + if err != nil { + return nil, err + } + + args := []string{ + "ssh", "sftp", + "--uid", strconv.FormatUint(uint64(uid), 10), + "--gid", strconv.FormatUint(uint64(gid), 10), + "--working-dir", workingDir, + } + + for _, group := range groups { + args = append(args, "--groups", strconv.FormatUint(uint64(group), 10)) + } + + log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args) + return exec.CommandContext(sess.Context(), netbirdPath, args...), nil +} diff --git a/client/ssh/server/sftp_windows.go b/client/ssh/server/sftp_windows.go new file mode 100644 index 000000000..dc532b9e7 --- /dev/null +++ b/client/ssh/server/sftp_windows.go @@ -0,0 +1,91 @@ +//go:build windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/user" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +// createSftpCommand creates a Windows SFTP command with user switching. +// The caller must close the returned token handle after starting the process. +func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) { + username, domain := s.parseUsername(targetUser.Username) + + netbirdPath, err := os.Executable() + if err != nil { + return nil, 0, fmt.Errorf("get netbird executable path: %w", err) + } + + args := []string{ + "ssh", "sftp", + "--working-dir", targetUser.HomeDir, + "--windows-username", username, + "--windows-domain", domain, + } + + pd := NewPrivilegeDropper() + token, err := pd.createToken(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("create token: %w", err) + } + + defer func() { + if err := windows.CloseHandle(token); err != nil { + log.Warnf("failed to close impersonation token: %v", err) + } + }() + + cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir) + if err != nil { + return nil, 0, fmt.Errorf("create SFTP command: %w", err) + } + + log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username) + return cmd, primaryToken, nil +} + +// executeSftpCommand executes a Windows SFTP command with proper I/O handling +func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error { + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("close primary token: %v", err) + } + }() + + sftpCmd.Stdin = sess + sftpCmd.Stdout = sess + sftpCmd.Stderr = sess.Stderr() + + if err := sftpCmd.Start(); err != nil { + return fmt.Errorf("starting sftp executor: %w", err) + } + + if err := sftpCmd.Wait(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + log.Tracef("sftp process exited with code %d", exitError.ExitCode()) + return nil + } + + return fmt.Errorf("exec sftp: %w", err) + } + + return nil +} + +// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping +func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { + sftpCmd, token, err := s.createSftpCommand(targetUser, sess) + if err != nil { + return fmt.Errorf("create sftp: %w", err) + } + return s.executeSftpCommand(sess, sftpCmd, token) +} diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go new file mode 100644 index 000000000..fea9d2910 --- /dev/null +++ b/client/ssh/server/shell.go @@ -0,0 +1,180 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "os" + "os/exec" + "os/user" + "runtime" + "strconv" + "strings" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +const ( + defaultUnixShell = "/bin/sh" + + pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name + powershellExe = "powershell.exe" +) + +// getUserShell returns the appropriate shell for the given user ID +// Handles all platform-specific logic and fallbacks consistently +func getUserShell(userID string) string { + switch runtime.GOOS { + case "windows": + return getWindowsUserShell() + default: + return getUnixUserShell(userID) + } +} + +// getWindowsUserShell returns the best shell for Windows users. +// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection +// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters. +// PowerShell provides safer argument handling and is available on all modern Windows systems. +// Order: pwsh.exe -> powershell.exe +func getWindowsUserShell() string { + if path, err := exec.LookPath(pwshExe); err == nil { + return path + } + if path, err := exec.LookPath(powershellExe); err == nil { + return path + } + + return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` +} + +// getUnixUserShell returns the shell for Unix-like systems +func getUnixUserShell(userID string) string { + shell := getShellFromPasswd(userID) + if shell != "" { + return shell + } + + if shell := os.Getenv("SHELL"); shell != "" { + return shell + } + + return defaultUnixShell +} + +// getShellFromPasswd reads the shell from /etc/passwd for the given user ID +func getShellFromPasswd(userID string) string { + file, err := os.Open("/etc/passwd") + if err != nil { + return "" + } + defer func() { + if err := file.Close(); err != nil { + log.Warnf("close /etc/passwd file: %v", err) + } + }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Split(line, ":") + if len(fields) < 7 { + continue + } + + // field 2 is UID + if fields[2] == userID { + shell := strings.TrimSpace(fields[6]) + return shell + } + } + + if err := scanner.Err(); err != nil { + log.Warnf("error reading /etc/passwd: %v", err) + } + + return "" +} + +// prepareUserEnv prepares environment variables for user execution +func prepareUserEnv(user *user.User, shell string) []string { + pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games" + if runtime.GOOS == "windows" { + pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0` + } + + return []string{ + fmt.Sprint("SHELL=" + shell), + fmt.Sprint("USER=" + user.Username), + fmt.Sprint("LOGNAME=" + user.Username), + fmt.Sprint("HOME=" + user.HomeDir), + "PATH=" + pathValue, + } +} + +// acceptEnv checks if environment variable from SSH client should be accepted +// This is a whitelist of variables that SSH clients can send to the server +func acceptEnv(envVar string) bool { + varName := envVar + if idx := strings.Index(envVar, "="); idx != -1 { + varName = envVar[:idx] + } + + exactMatches := []string{ + "LANG", + "LANGUAGE", + "TERM", + "COLORTERM", + "EDITOR", + "VISUAL", + "PAGER", + "LESS", + "LESSCHARSET", + "TZ", + } + + prefixMatches := []string{ + "LC_", + } + + for _, exact := range exactMatches { + if varName == exact { + return true + } + } + + for _, prefix := range prefixMatches { + if strings.HasPrefix(varName, prefix) { + return true + } + } + + return false +} + +// prepareSSHEnv prepares SSH protocol-specific environment variables +// These variables provide information about the SSH connection itself +func prepareSSHEnv(session ssh.Session) []string { + remoteAddr := session.RemoteAddr() + localAddr := session.LocalAddr() + + remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String()) + if err != nil { + remoteHost = remoteAddr.String() + remotePort = "0" + } + + localHost, localPort, err := net.SplitHostPort(localAddr.String()) + if err != nil { + localHost = localAddr.String() + localPort = strconv.Itoa(InternalSSHPort) + } + + return []string{ + // SSH_CLIENT format: "client_ip client_port server_port" + fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort), + // SSH_CONNECTION format: "client_ip client_port server_ip server_port" + fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort), + } +} diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go new file mode 100644 index 000000000..20930c721 --- /dev/null +++ b/client/ssh/server/test.go @@ -0,0 +1,45 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" +) + +func StartTestServer(t *testing.T, server *Server) string { + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort := netip.MustParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + return actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + return "" +} diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go new file mode 100644 index 000000000..799882cbb --- /dev/null +++ b/client/ssh/server/user_utils.go @@ -0,0 +1,411 @@ +package server + +import ( + "errors" + "fmt" + "os" + "os/user" + "runtime" + "strings" + + log "github.com/sirupsen/logrus" +) + +var ( + ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges") + ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges") +) + +// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.) +func isPlatformUnix() bool { + return getCurrentOS() != "windows" +} + +// Dependency injection variables for testing - allows mocking dynamic runtime checks +var ( + getCurrentUser = user.Current + lookupUser = user.Lookup + getCurrentOS = func() string { return runtime.GOOS } + getIsProcessPrivileged = isCurrentProcessPrivileged + + getEuid = os.Geteuid +) + +const ( + // FeatureSSHLogin represents SSH login operations for privilege checking + FeatureSSHLogin = "SSH login" + // FeatureSFTP represents SFTP operations for privilege checking + FeatureSFTP = "SFTP" +) + +// PrivilegeCheckRequest represents a privilege check request +type PrivilegeCheckRequest struct { + // Username being requested (empty = current user) + RequestedUsername string + FeatureSupportsUserSwitch bool // Does this feature/operation support user switching? + FeatureName string +} + +// PrivilegeCheckResult represents the result of a privilege check +type PrivilegeCheckResult struct { + // Allowed indicates whether the privilege check passed + Allowed bool + // User is the effective user to use for the operation (nil if not allowed) + User *user.User + // Error contains the reason for denial (nil if allowed) + Error error + // UsedFallback indicates we fell back to current user instead of requested user. + // This happens on Unix when running as an unprivileged user (e.g., in containers) + // where there's no point in user switching since we lack privileges anyway. + // When true, all privilege checks have already been performed and no additional + // privilege dropping or root checks are needed - the current user is the target. + UsedFallback bool + // RequiresUserSwitching indicates whether user switching will actually occur + // (false for fallback cases where no actual switching happens) + RequiresUserSwitching bool +} + +// CheckPrivileges performs comprehensive privilege checking for all SSH features. +// This is the single source of truth for privilege decisions across the SSH server. +func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult { + context, err := s.buildPrivilegeCheckContext(req.FeatureName) + if err != nil { + return PrivilegeCheckResult{Allowed: false, Error: err} + } + + // Handle empty username case - but still check root access controls + if req.RequestedUsername == "" { + if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot { + return PrivilegeCheckResult{ + Allowed: false, + Error: &PrivilegedUserError{Username: context.currentUser.Username}, + } + } + return PrivilegeCheckResult{ + Allowed: true, + User: context.currentUser, + RequiresUserSwitching: false, + } + } + + return s.checkUserRequest(context, req) +} + +// buildPrivilegeCheckContext gathers all the context needed for privilege checking +func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) { + currentUser, err := getCurrentUser() + if err != nil { + return nil, fmt.Errorf("get current user for %s: %w", featureName, err) + } + + s.mu.RLock() + allowRoot := s.allowRootLogin + s.mu.RUnlock() + + return &privilegeCheckContext{ + currentUser: currentUser, + currentUserPrivileged: getIsProcessPrivileged(), + allowRoot: allowRoot, + }, nil +} + +// checkUserRequest handles normal privilege checking flow for specific usernames +func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult { + if !ctx.currentUserPrivileged && isPlatformUnix() { + log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)", + ctx.currentUser.Username, req.FeatureName, req.RequestedUsername) + return PrivilegeCheckResult{ + Allowed: true, + User: ctx.currentUser, + UsedFallback: true, + RequiresUserSwitching: false, + } + } + + resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername) + if err != nil { + // Calculate if user switching would be required even if lookup failed + needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username) + return PrivilegeCheckResult{ + Allowed: false, + Error: err, + RequiresUserSwitching: needsUserSwitching, + } + } + + needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser) + + if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot { + return PrivilegeCheckResult{ + Allowed: false, + Error: &PrivilegedUserError{Username: resolvedUser.Username}, + RequiresUserSwitching: needsUserSwitching, + } + } + + if needsUserSwitching && !req.FeatureSupportsUserSwitch { + return PrivilegeCheckResult{ + Allowed: false, + Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName), + RequiresUserSwitching: needsUserSwitching, + } + } + + return PrivilegeCheckResult{ + Allowed: true, + User: resolvedUser, + RequiresUserSwitching: needsUserSwitching, + } +} + +// resolveRequestedUser resolves a username to its canonical user identity +func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) { + if requestedUsername == "" { + return getCurrentUser() + } + + if err := validateUsername(requestedUsername); err != nil { + return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err) + } + + u, err := lookupUser(requestedUsername) + if err != nil { + return nil, &UserNotFoundError{Username: requestedUsername, Cause: err} + } + return u, nil +} + +// isSameResolvedUser compares two resolved user identities +func isSameResolvedUser(user1, user2 *user.User) bool { + if user1 == nil || user2 == nil { + return user1 == user2 + } + return user1.Uid == user2.Uid +} + +// privilegeCheckContext holds all context needed for privilege checking +type privilegeCheckContext struct { + currentUser *user.User + currentUserPrivileged bool + allowRoot bool +} + +// isSameUser checks if two usernames refer to the same user +// SECURITY: This function must be conservative - it should only return true +// when we're certain both usernames refer to the exact same user identity +func isSameUser(requestedUsername, currentUsername string) bool { + // Empty requested username means current user + if requestedUsername == "" { + return true + } + + // Exact match (most common case) + if getCurrentOS() == "windows" { + if strings.EqualFold(requestedUsername, currentUsername) { + return true + } + } else { + if requestedUsername == currentUsername { + return true + } + } + + // Windows domain resolution: only allow domain stripping when comparing + // a bare username against the current user's domain-qualified name + if getCurrentOS() == "windows" { + return isWindowsSameUser(requestedUsername, currentUsername) + } + + return false +} + +// isWindowsSameUser handles Windows-specific user comparison with domain logic +func isWindowsSameUser(requestedUsername, currentUsername string) bool { + // Extract domain and username parts + extractParts := func(name string) (domain, user string) { + // Handle DOMAIN\username format + if idx := strings.LastIndex(name, `\`); idx != -1 { + return name[:idx], name[idx+1:] + } + // Handle user@domain.com format + if idx := strings.Index(name, "@"); idx != -1 { + return name[idx+1:], name[:idx] + } + // No domain specified - local machine + return "", name + } + + reqDomain, reqUser := extractParts(requestedUsername) + curDomain, curUser := extractParts(currentUsername) + + // Case-insensitive username comparison + if !strings.EqualFold(reqUser, curUser) { + return false + } + + // If requested username has no domain, it refers to local machine user + // Allow this to match the current user regardless of current user's domain + if reqDomain == "" { + return true + } + + // If both have domains, they must match exactly (case-insensitive) + return strings.EqualFold(reqDomain, curDomain) +} + +// SetAllowRootLogin configures root login access +func (s *Server) SetAllowRootLogin(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowRootLogin = allow +} + +// userNameLookup performs user lookup with root login permission check +func (s *Server) userNameLookup(username string) (*user.User, error) { + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSSHLogin, + }) + + if !result.Allowed { + return nil, result.Error + } + + return result.User, nil +} + +// userPrivilegeCheck performs user lookup with full privilege check result +func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) { + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSSHLogin, + }) + + if !result.Allowed { + return result, result.Error + } + + return result, nil +} + +// isPrivilegedUsername checks if the given username represents a privileged user across platforms. +// On Unix: root +// On Windows: Administrator, SYSTEM (case-insensitive) +// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com" +func isPrivilegedUsername(username string) bool { + if getCurrentOS() != "windows" { + return username == "root" + } + + bareUsername := username + // Handle Windows domain format: DOMAIN\username + if idx := strings.LastIndex(username, `\`); idx != -1 { + bareUsername = username[idx+1:] + } + // Handle email-style format: username@domain.com + if idx := strings.Index(bareUsername, "@"); idx != -1 { + bareUsername = bareUsername[:idx] + } + + return isWindowsPrivilegedUser(bareUsername) +} + +// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account +func isWindowsPrivilegedUser(bareUsername string) bool { + // common privileged usernames (case insensitive) + privilegedNames := []string{ + "administrator", + "admin", + "root", + "system", + "localsystem", + "networkservice", + "localservice", + } + + usernameLower := strings.ToLower(bareUsername) + for _, privilegedName := range privilegedNames { + if usernameLower == privilegedName { + return true + } + } + + // computer accounts (ending with $) are not privileged by themselves + // They only gain privileges through group membership or specific SIDs + + if targetUser, err := lookupUser(bareUsername); err == nil { + return isWindowsPrivilegedSID(targetUser.Uid) + } + + return false +} + +// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account +func isWindowsPrivilegedSID(sid string) bool { + privilegedSIDs := []string{ + "S-1-5-18", // Local System (SYSTEM) + "S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE) + "S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE) + "S-1-5-32-544", // Administrators group (BUILTIN\Administrators) + "S-1-5-500", // Built-in Administrator account (local machine RID 500) + } + + for _, privilegedSID := range privilegedSIDs { + if sid == privilegedSID { + return true + } + } + + // Check for domain administrator accounts (RID 500 in any domain) + // Format: S-1-5-21-domain-domain-domain-500 + // This is reliable as RID 500 is reserved for the domain Administrator account + if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") { + return true + } + + // Check for other well-known privileged RIDs in domain contexts + // RID 512 = Domain Admins group, RID 516 = Domain Controllers group + if strings.HasPrefix(sid, "S-1-5-21-") { + if strings.HasSuffix(sid, "-512") || // Domain Admins group + strings.HasSuffix(sid, "-516") || // Domain Controllers group + strings.HasSuffix(sid, "-519") { // Enterprise Admins group + return true + } + } + + return false +} + +// isCurrentProcessPrivileged checks if the current process is running with elevated privileges. +// On Unix systems, this means running as root (UID 0). +// On Windows, this means running as Administrator or SYSTEM. +func isCurrentProcessPrivileged() bool { + if getCurrentOS() == "windows" { + return isWindowsElevated() + } + return getEuid() == 0 +} + +// isWindowsElevated checks if the current process is running with elevated privileges on Windows +func isWindowsElevated() bool { + currentUser, err := getCurrentUser() + if err != nil { + log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err) + return false + } + + if isWindowsPrivilegedSID(currentUser.Uid) { + log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid) + return true + } + + if isPrivilegedUsername(currentUser.Username) { + log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username) + return true + } + + log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid) + return false +} diff --git a/client/ssh/server/user_utils_js.go b/client/ssh/server/user_utils_js.go new file mode 100644 index 000000000..163b24c6c --- /dev/null +++ b/client/ssh/server/user_utils_js.go @@ -0,0 +1,8 @@ +//go:build js + +package server + +// validateUsername is not supported on JS/WASM +func validateUsername(_ string) error { + return errNotSupported +} diff --git a/client/ssh/server/user_utils_test.go b/client/ssh/server/user_utils_test.go new file mode 100644 index 000000000..637dc10d0 --- /dev/null +++ b/client/ssh/server/user_utils_test.go @@ -0,0 +1,908 @@ +package server + +import ( + "errors" + "os/user" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Test helper functions +func createTestUser(username, uid, gid, homeDir string) *user.User { + return &user.User{ + Uid: uid, + Gid: gid, + Username: username, + Name: username, + HomeDir: homeDir, + } +} + +// Test dependency injection setup - injects platform dependencies to test real logic +func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() { + // Store originals + originalGetCurrentUser := getCurrentUser + originalLookupUser := lookupUser + originalGetCurrentOS := getCurrentOS + originalGetEuid := getEuid + + // Reset caches to ensure clean test state + + // Set test values - inject platform dependencies + getCurrentUser = func() (*user.User, error) { + return currentUser, currentUserErr + } + + lookupUser = func(username string) (*user.User, error) { + if err, exists := lookupErrors[username]; exists { + return nil, err + } + if userObj, exists := lookupUsers[username]; exists { + return userObj, nil + } + return nil, errors.New("user: unknown user " + username) + } + + getCurrentOS = func() string { + return os + } + + getEuid = func() int { + return euid + } + + // Mock privilege detection based on the test user + getIsProcessPrivileged = func() bool { + if currentUser == nil { + return false + } + // Check both username and SID for Windows systems + if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) { + return true + } + return isPrivilegedUsername(currentUser.Username) + } + + // Return cleanup function + return func() { + getCurrentUser = originalGetCurrentUser + lookupUser = originalLookupUser + getCurrentOS = originalGetCurrentOS + getEuid = originalGetEuid + + getIsProcessPrivileged = isCurrentProcessPrivileged + + // Reset caches after test + } +} + +func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) { + tests := []struct { + name string + os string + euid int + currentUser *user.User + requestedUsername string + featureSupportsUserSwitch bool + allowRoot bool + lookupUsers map[string]*user.User + expectedAllowed bool + expectedRequiresSwitch bool + }{ + { + name: "linux_root_can_switch_to_alice", + os: "linux", + euid: 0, // Root process + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "alice", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "1000", "1000", "/home/alice"), + }, + expectedAllowed: true, + expectedRequiresSwitch: true, + }, + { + name: "linux_non_root_fallback_to_current_user", + os: "linux", + euid: 1000, // Non-root process + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "bob", + featureSupportsUserSwitch: true, + allowRoot: true, + expectedAllowed: true, // Should fallback to current user (alice) + expectedRequiresSwitch: false, // Fallback means no actual switching + }, + { + name: "windows_admin_can_switch_to_alice", + os: "windows", + euid: 1000, // Irrelevant on Windows + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "alice", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: true, + expectedRequiresSwitch: true, + }, + { + name: "windows_non_admin_no_fallback_hard_failure", + os: "windows", + euid: 1000, // Irrelevant on Windows + currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"), + requestedUsername: "bob", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"), + }, + expectedAllowed: true, // Let OS decide - deferred security check + expectedRequiresSwitch: true, // Different user was requested + }, + // Comprehensive test matrix: non-root linux with different allowRoot settings + { + name: "linux_non_root_request_root_allowRoot_false", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Fallback allows access regardless of root setting + expectedRequiresSwitch: false, // Fallback case, no switching + }, + { + name: "linux_non_root_request_root_allowRoot_true", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + expectedAllowed: true, // Should fallback to alice (non-privileged process) + expectedRequiresSwitch: false, // Fallback means no actual switching + }, + // Windows admin test matrix + { + name: "windows_admin_request_root_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed + expectedRequiresSwitch: true, + }, + { + name: "windows_admin_request_root_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Windows user switching should work like Unix + expectedRequiresSwitch: true, + }, + // Windows non-admin test matrix + { + name: "windows_non_admin_request_root_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence) + expectedRequiresSwitch: true, + }, + { + name: "windows_system_account_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed + expectedRequiresSwitch: true, + }, + { + name: "windows_system_account_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // SYSTEM can switch to root + expectedRequiresSwitch: true, + }, + { + name: "windows_non_admin_request_root_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Let OS decide - deferred security check + expectedRequiresSwitch: true, + }, + + // Feature doesn't support user switching scenarios + { + name: "linux_root_feature_no_user_switching_same_user", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "root", // Same user + featureSupportsUserSwitch: false, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Same user should work regardless of feature support + expectedRequiresSwitch: false, + }, + { + name: "linux_root_feature_no_user_switching_different_user", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "alice", + featureSupportsUserSwitch: false, // Feature doesn't support switching + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "1000", "1000", "/home/alice"), + }, + expectedAllowed: false, // Should deny because feature doesn't support switching + expectedRequiresSwitch: true, + }, + + // Empty username (current user) scenarios + { + name: "linux_non_root_current_user_empty_username", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "", // Empty = current user + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Current user should always work + expectedRequiresSwitch: false, + }, + { + name: "linux_root_current_user_empty_username_root_not_allowed", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "", // Empty = current user (root) + featureSupportsUserSwitch: true, + allowRoot: false, // Root not allowed + expectedAllowed: false, // Should deny root even when it's current user + expectedRequiresSwitch: false, + }, + + // User not found scenarios + { + name: "linux_root_user_not_found", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "nonexistent", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{}, // No users defined = user not found + expectedAllowed: false, // Should fail due to user not found + expectedRequiresSwitch: true, + }, + + // Windows feature doesn't support user switching + { + name: "windows_admin_feature_no_user_switching_different_user", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "alice", + featureSupportsUserSwitch: false, // Feature doesn't support switching + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: false, // Should deny because feature doesn't support switching + expectedRequiresSwitch: true, + }, + + // Windows regular user scenarios (non-admin) + { + name: "windows_regular_user_same_user", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "alice", // Same user + featureSupportsUserSwitch: true, + allowRoot: false, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: true, // Regular user accessing themselves should work + expectedRequiresSwitch: false, // No switching for same user + }, + { + name: "windows_regular_user_empty_username", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "", // Empty = current user + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Current user should always work + expectedRequiresSwitch: false, // No switching for current user + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Inject platform dependencies to test real logic + cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch, + FeatureName: "SSH login", + }) + + assert.Equal(t, tt.expectedAllowed, result.Allowed) + assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching) + }) + } +} + +func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Fallback mechanism is Unix-specific") + } + + // Create test scenario where fallback should occur + server := &Server{allowRootLogin: true} + + // Mock dependencies to simulate non-privileged user + originalGetCurrentUser := getCurrentUser + originalGetIsProcessPrivileged := getIsProcessPrivileged + + defer func() { + getCurrentUser = originalGetCurrentUser + getIsProcessPrivileged = originalGetIsProcessPrivileged + + }() + + // Set up mocks for fallback scenario + getCurrentUser = func() (*user.User, error) { + return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil + } + getIsProcessPrivileged = func() bool { return false } // Non-privileged + + // Request different user - should fallback + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "alice", + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + // Verify fallback occurred + assert.True(t, result.Allowed, "Should allow with fallback") + assert.True(t, result.UsedFallback, "Should indicate fallback was used") + assert.Equal(t, "netbird", result.User.Username, "Should return current user") + assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used") + + // Key assertion: When UsedFallback is true, no privilege dropping should be needed + // because all privilege checks have already been performed and we're using current user + t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed", + result.User.Username) +} + +func TestPrivilegedUsernameDetection(t *testing.T) { + tests := []struct { + name string + username string + platform string + privileged bool + }{ + // Unix/Linux tests + {"unix_root", "root", "linux", true}, + {"unix_regular_user", "alice", "linux", false}, + {"unix_root_capital", "Root", "linux", false}, // Case-sensitive + + // Windows tests + {"windows_administrator", "Administrator", "windows", true}, + {"windows_system", "SYSTEM", "windows", true}, + {"windows_admin", "admin", "windows", true}, + {"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive + {"windows_domain_admin", "DOMAIN\\Administrator", "windows", true}, + {"windows_email_admin", "admin@domain.com", "windows", true}, + {"windows_regular_user", "alice", "windows", false}, + {"windows_domain_user", "DOMAIN\\alice", "windows", false}, + {"windows_localsystem", "localsystem", "windows", true}, + {"windows_networkservice", "networkservice", "windows", true}, + {"windows_localservice", "localservice", "windows", true}, + + // Computer accounts (these depend on current user context in real implementation) + {"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged + {"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account + + // Cross-platform + {"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the platform for this test + cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil) + defer cleanup() + + result := isPrivilegedUsername(tt.username) + assert.Equal(t, tt.privileged, result) + }) + } +} + +func TestWindowsPrivilegedSIDDetection(t *testing.T) { + tests := []struct { + name string + sid string + privileged bool + description string + }{ + // Well-known system accounts + {"system_account", "S-1-5-18", true, "Local System (SYSTEM)"}, + {"local_service", "S-1-5-19", true, "Local Service"}, + {"network_service", "S-1-5-20", true, "Network Service"}, + {"administrators_group", "S-1-5-32-544", true, "Administrators group"}, + {"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"}, + + // Domain accounts + {"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"}, + {"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"}, + {"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"}, + {"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"}, + + // Regular users + {"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"}, + {"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"}, + {"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"}, + + // Groups that are not privileged + {"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"}, + {"power_users", "S-1-5-32-547", false, "Power Users group"}, + + // Invalid SIDs + {"malformed_sid", "S-1-5-invalid", false, "Malformed SID"}, + {"empty_sid", "", false, "Empty SID"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isWindowsPrivilegedSID(tt.sid) + assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid) + }) + } +} + +func TestIsSameUser(t *testing.T) { + tests := []struct { + name string + user1 string + user2 string + os string + expected bool + }{ + // Basic cases + {"same_username", "alice", "alice", "linux", true}, + {"different_username", "alice", "bob", "linux", false}, + + // Linux (no domain processing) + {"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false}, + {"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false}, + {"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true}, + + // Windows (with domain processing) - Note: parameter order is (requested, current, os, expected) + {"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user + {"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user + {"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users + {"windows_case_insensitive", "Alice", "alice", "windows", true}, + {"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up OS mock + cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil) + defer cleanup() + + result := isSameUser(tt.user1, tt.user2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUsernameValidation_Unix(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix-specific username validation tests") + } + + tests := []struct { + name string + username string + wantErr bool + errMsg string + }{ + // Valid usernames (Unix/POSIX) + {"valid_alphanumeric", "user123", false, ""}, + {"valid_with_dots", "user.name", false, ""}, + {"valid_with_hyphens", "user-name", false, ""}, + {"valid_with_underscores", "user_name", false, ""}, + {"valid_uppercase", "UserName", false, ""}, + {"valid_starting_with_digit", "123user", false, ""}, + {"valid_starting_with_dot", ".hidden", false, ""}, + + // Invalid usernames (Unix/POSIX) + {"empty_username", "", true, "username cannot be empty"}, + {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"}, + {"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction + {"username_with_spaces", "user name", true, "invalid characters"}, + {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"}, + {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"}, + {"username_with_pipe", "user|rm", true, "invalid characters"}, + {"username_with_ampersand", "user&rm", true, "invalid characters"}, + {"username_with_quotes", "user\"name", true, "invalid characters"}, + {"username_with_newline", "user\nname", true, "invalid characters"}, + {"reserved_dot", ".", true, "cannot be '.' or '..'"}, + {"reserved_dotdot", "..", true, "cannot be '.' or '..'"}, + {"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames + {"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.username) + if tt.wantErr { + assert.Error(t, err, "Should reject invalid username") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text") + } + } else { + assert.NoError(t, err, "Should accept valid username") + } + }) + } +} + +func TestUsernameValidation_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific username validation tests") + } + + tests := []struct { + name string + username string + wantErr bool + errMsg string + }{ + // Valid usernames (Windows) + {"valid_alphanumeric", "user123", false, ""}, + {"valid_with_dots", "user.name", false, ""}, + {"valid_with_hyphens", "user-name", false, ""}, + {"valid_with_underscores", "user_name", false, ""}, + {"valid_uppercase", "UserName", false, ""}, + {"valid_starting_with_digit", "123user", false, ""}, + {"valid_starting_with_dot", ".hidden", false, ""}, + {"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this + {"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format + {"valid_email_username", "user@domain.com", false, ""}, // Windows email format + {"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format + + // Invalid usernames (Windows) + {"empty_username", "", true, "username cannot be empty"}, + {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"}, + {"username_with_spaces", "user name", true, "invalid characters"}, + {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"}, + {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"}, + {"username_with_pipe", "user|rm", true, "invalid characters"}, + {"username_with_ampersand", "user&rm", true, "invalid characters"}, + {"username_with_quotes", "user\"name", true, "invalid characters"}, + {"username_with_newline", "user\nname", true, "invalid characters"}, + {"username_with_brackets", "user[name]", true, "invalid characters"}, + {"username_with_colon", "user:name", true, "invalid characters"}, + {"username_with_semicolon", "user;name", true, "invalid characters"}, + {"username_with_equals", "user=name", true, "invalid characters"}, + {"username_with_comma", "user,name", true, "invalid characters"}, + {"username_with_plus", "user+name", true, "invalid characters"}, + {"username_with_asterisk", "user*name", true, "invalid characters"}, + {"username_with_question", "user?name", true, "invalid characters"}, + {"username_with_angles", "user", true, "invalid characters"}, + {"reserved_dot", ".", true, "cannot be '.' or '..'"}, + {"reserved_dotdot", "..", true, "cannot be '.' or '..'"}, + {"username_ending_with_period", "user.", true, "cannot end with a period"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.username) + if tt.wantErr { + assert.Error(t, err, "Should reject invalid username") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text") + } + } else { + assert.NoError(t, err, "Should accept valid username") + } + }) + } +} + +// Test real-world integration scenarios with actual platform capabilities +func TestCheckPrivileges_RealWorldScenarios(t *testing.T) { + tests := []struct { + name string + feature string + featureSupportsUserSwitch bool + requestedUsername string + allowRoot bool + expectedBehaviorPattern string + }{ + {"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"}, + {"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"}, + {"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"}, + {"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"}, + {"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock privileged environment to ensure consistent test behavior across environments + cleanup := setupTestDependencies( + createTestUser("root", "0", "0", "/root"), // Running as root + nil, + runtime.GOOS, + 0, // euid 0 (root) + map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"), + }, + nil, + ) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch, + FeatureName: tt.feature, + }) + + switch tt.expectedBehaviorPattern { + case "should_allow_current_user": + assert.True(t, result.Allowed, "Should allow current user access") + assert.False(t, result.RequiresUserSwitching, "Current user should not require switching") + case "should_deny_root": + assert.False(t, result.Allowed, "Should deny root when not allowed") + assert.Contains(t, result.Error.Error(), "root", "Should mention root in error") + case "should_deny_switching": + assert.False(t, result.Allowed, "Should deny when feature doesn't support switching") + assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error") + } + }) + } +} + +// Test with actual platform capabilities - no mocking +func TestCheckPrivileges_ActualPlatform(t *testing.T) { + // This test uses the REAL platform capabilities + server := &Server{allowRootLogin: true} + + // Test current user access - should always work + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "", // Current user + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + assert.True(t, result.Allowed, "Current user should always be allowed") + assert.False(t, result.RequiresUserSwitching, "Current user should not require switching") + assert.NotNil(t, result.User, "Should return current user") + + // Test user switching capability based on actual platform + actualIsPrivileged := isCurrentProcessPrivileged() // REAL check + actualOS := runtime.GOOS // REAL check + + t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v", + actualOS, actualIsPrivileged, actualIsPrivileged) + + // Test requesting different user + result = server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "nonexistentuser", + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + switch { + case actualOS == "windows": + // Windows supports user switching but should fail on nonexistent user + assert.False(t, result.Allowed, "Windows should deny nonexistent user") + assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") + assert.Contains(t, result.Error.Error(), "not found", + "Should indicate user not found") + case !actualIsPrivileged: + // Non-privileged Unix processes should fallback to current user + assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user") + assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens") + assert.True(t, result.UsedFallback, "Should indicate fallback was used") + assert.NotNil(t, result.User, "Should return current user") + default: + // Privileged Unix processes should attempt user lookup + assert.False(t, result.Allowed, "Should fail due to nonexistent user") + assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") + assert.Contains(t, result.Error.Error(), "nonexistentuser", + "Should indicate user not found") + } +} + +// Test platform detection logic with dependency injection +func TestPlatformLogic_DependencyInjection(t *testing.T) { + tests := []struct { + name string + os string + euid int + currentUser *user.User + expectedIsProcessPrivileged bool + expectedSupportsUserSwitching bool + }{ + { + name: "linux_root_process", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + expectedIsProcessPrivileged: true, + expectedSupportsUserSwitching: true, + }, + { + name: "linux_non_root_process", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + expectedIsProcessPrivileged: false, + expectedSupportsUserSwitching: false, + }, + { + name: "windows_admin_process", + os: "windows", + euid: 1000, // euid ignored on Windows + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + expectedIsProcessPrivileged: true, + expectedSupportsUserSwitching: true, // Windows supports user switching when privileged + }, + { + name: "windows_regular_process", + os: "windows", + euid: 1000, // euid ignored on Windows + currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"), + expectedIsProcessPrivileged: false, + expectedSupportsUserSwitching: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Inject platform dependencies and test REAL logic + cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil) + defer cleanup() + + // Test the actual functions with injected dependencies + actualIsPrivileged := isCurrentProcessPrivileged() + actualSupportsUserSwitching := actualIsPrivileged + + assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged, + "isCurrentProcessPrivileged() result mismatch") + assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching, + "supportsUserSwitching() result mismatch") + + t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username) + t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v", + actualIsPrivileged, actualSupportsUserSwitching) + }) + } +} + +func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) { + // Test Windows elevated user switching scenarios with simplified privilege logic + tests := []struct { + name string + currentUser *user.User + requestedUsername string + allowRoot bool + expectedAllowed bool + expectedErrorContains string + }{ + { + name: "windows_admin_can_switch_to_alice", + currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"), + requestedUsername: "alice", + allowRoot: true, + expectedAllowed: true, + }, + { + name: "windows_non_admin_can_try_switch", + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"), + requestedUsername: "bob", + allowRoot: true, + expectedAllowed: true, // Privilege check allows it, OS will reject during execution + }, + { + name: "windows_system_can_switch_to_alice", + currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"), + requestedUsername: "alice", + allowRoot: true, + expectedAllowed: true, + }, + { + name: "windows_admin_root_not_allowed", + currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"), + requestedUsername: "root", + allowRoot: false, + expectedAllowed: false, + expectedErrorContains: "privileged user login is disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test dependencies with Windows OS and specified privileges + lookupUsers := map[string]*user.User{ + tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername), + } + cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + assert.Equal(t, tt.expectedAllowed, result.Allowed, + "Privilege check result should match expected for %s", tt.name) + + if !tt.expectedAllowed && tt.expectedErrorContains != "" { + assert.NotNil(t, result.Error, "Should have error when not allowed") + assert.Contains(t, result.Error.Error(), tt.expectedErrorContains, + "Error should contain expected message") + } + + if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername { + assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user") + } + }) + } +} diff --git a/client/ssh/server/userswitching_js.go b/client/ssh/server/userswitching_js.go new file mode 100644 index 000000000..333c19259 --- /dev/null +++ b/client/ssh/server/userswitching_js.go @@ -0,0 +1,8 @@ +//go:build js + +package server + +// enableUserSwitching is not supported on JS/WASM +func enableUserSwitching() error { + return errNotSupported +} diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go new file mode 100644 index 000000000..bc1557419 --- /dev/null +++ b/client/ssh/server/userswitching_unix.go @@ -0,0 +1,260 @@ +//go:build unix + +package server + +import ( + "errors" + "fmt" + "net" + "net/netip" + "os" + "os/exec" + "os/user" + "regexp" + "runtime" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// POSIX portable filename character set regex: [a-zA-Z0-9._-] +// First character cannot be hyphen (POSIX requirement) +var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`) + +// validateUsername validates that a username conforms to POSIX standards with security considerations +func validateUsername(username string) error { + if username == "" { + return errors.New("username cannot be empty") + } + + // POSIX allows up to 256 characters, but practical limit is 32 for compatibility + if len(username) > 32 { + return errors.New("username too long (max 32 characters)") + } + + if !posixUsernameRegex.MatchString(username) { + return errors.New("username contains invalid characters (must match POSIX portable filename character set)") + } + + if username == "." || username == ".." { + return fmt.Errorf("username cannot be '.' or '..'") + } + + // Warn if username is fully numeric (can cause issues with UID/username ambiguity) + if isFullyNumeric(username) { + log.Warnf("fully numeric username '%s' may cause issues with some commands", username) + } + + return nil +} + +// isFullyNumeric checks if username contains only digits +func isFullyNumeric(username string) bool { + for _, char := range username { + if char < '0' || char > '9' { + return false + } + } + return true +} + +// createPtyLoginCommand creates a Pty command using login for privileged processes +func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { + loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("get login command: %w", err) + } + + execCmd := exec.CommandContext(session.Context(), loginPath, args...) + execCmd.Dir = localUser.HomeDir + execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session) + + return execCmd, nil +} + +// getLoginCmd returns the login command and args for privileged Pty user switching +func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) { + loginPath, err := exec.LookPath("login") + if err != nil { + return "", nil, fmt.Errorf("login command not available: %w", err) + } + + addrPort, err := netip.ParseAddrPort(remoteAddr.String()) + if err != nil { + return "", nil, fmt.Errorf("parse remote address: %w", err) + } + + switch runtime.GOOS { + case "linux": + p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String()) + return p, a, nil + case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": + return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil + default: + return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS) + } +} + +// getLinuxLoginCmd returns the login command for Linux systems. +// Handles differences between util-linux and shadow-utils login implementations. +func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) { + // Special handling for Arch Linux without /etc/pam.d/remote + var loginArgs []string + if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { + loginArgs = []string{"-f", username, "-p"} + } else { + loginArgs = []string{"-f", username, "-h", remoteIP, "-p"} + } + + // util-linux login requires setsid -c to create a new session and set the + // controlling terminal. Without this, vhangup() kills the parent process. + // See https://bugs.debian.org/1078023 for details. + // TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec() + // to avoid external setsid dependency. + if !s.loginIsUtilLinux { + return loginPath, loginArgs + } + + setsidPath, err := exec.LookPath("setsid") + if err != nil { + log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err) + return loginPath, loginArgs + } + + args := append([]string{"-w", "-c", loginPath}, loginArgs...) + return setsidPath, args +} + +// fileExists checks if a file exists +func (s *Server) fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +// parseUserCredentials extracts numeric UID, GID, and supplementary groups +func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) { + uid64, err := strconv.ParseUint(localUser.Uid, 10, 32) + if err != nil { + return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err) + } + uid := uint32(uid64) + + gid64, err := strconv.ParseUint(localUser.Gid, 10, 32) + if err != nil { + return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err) + } + gid := uint32(gid64) + + groups, err := s.getSupplementaryGroups(localUser.Username) + if err != nil { + log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + groups = []uint32{gid} + } + + return uid, gid, groups, nil +} + +// getSupplementaryGroups retrieves supplementary group IDs for a user +func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { + u, err := user.Lookup(username) + if err != nil { + return nil, fmt.Errorf("lookup user %s: %w", username, err) + } + + groupIDStrings, err := u.GroupIds() + if err != nil { + return nil, fmt.Errorf("get group IDs for user %s: %w", username, err) + } + + groups := make([]uint32, len(groupIDStrings)) + for i, gidStr := range groupIDStrings { + gid64, err := strconv.ParseUint(gidStr, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err) + } + groups[i] = uint32(gid64) + } + + return groups, nil +} + +// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. +// Returns the command and a cleanup function (no-op on Unix). +func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) + + if err := validateUsername(localUser.Username); err != nil { + return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) + } + + uid, gid, groups, err := s.parseUserCredentials(localUser) + if err != nil { + return nil, nil, fmt.Errorf("parse user credentials: %w", err) + } + privilegeDropper := NewPrivilegeDropper() + config := ExecutorConfig{ + UID: uid, + GID: gid, + Groups: groups, + WorkingDir: localUser.HomeDir, + Shell: getUserShell(localUser.Uid), + Command: session.RawCommand(), + PTY: hasPty, + } + + cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config) + return cmd, func() {}, err +} + +// enableUserSwitching is a no-op on Unix systems +func enableUserSwitching() error { + return nil +} + +// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results +func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { + localUser := privilegeResult.User + if localUser == nil { + return nil, errors.New("no user in privilege result") + } + + if privilegeResult.UsedFallback { + return s.createDirectPtyCommand(session, localUser, ptyReq), nil + } + + return s.createPtyLoginCommand(localUser, ptyReq, session) +} + +// createDirectPtyCommand creates a direct Pty command without privilege dropping +func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd { + log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username) + + shell := getUserShell(localUser.Uid) + args := s.getShellCommandArgs(shell, session.RawCommand()) + + cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) + cmd.Dir = localUser.HomeDir + cmd.Env = s.preparePtyEnv(localUser, ptyReq, session) + + return cmd +} + +// preparePtyEnv prepares environment variables for Pty execution +func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string { + termType := ptyReq.Term + if termType == "" { + termType = "xterm-256color" + } + + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + env = append(env, fmt.Sprintf("TERM=%s", termType)) + + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go new file mode 100644 index 000000000..5a5f75fa4 --- /dev/null +++ b/client/ssh/server/userswitching_windows.go @@ -0,0 +1,274 @@ +//go:build windows + +package server + +import ( + "errors" + "fmt" + "os/exec" + "os/user" + "strings" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +// validateUsername validates Windows usernames according to SAM Account Name rules +func validateUsername(username string) error { + if username == "" { + return fmt.Errorf("username cannot be empty") + } + + usernameToValidate := extractUsernameFromDomain(username) + + if err := validateUsernameLength(usernameToValidate); err != nil { + return err + } + + if err := validateUsernameCharacters(usernameToValidate); err != nil { + return err + } + + if err := validateUsernameFormat(usernameToValidate); err != nil { + return err + } + + return nil +} + +// extractUsernameFromDomain extracts the username part from domain\username or username@domain format +func extractUsernameFromDomain(username string) string { + if idx := strings.LastIndex(username, `\`); idx != -1 { + return username[idx+1:] + } + if idx := strings.Index(username, "@"); idx != -1 { + return username[:idx] + } + return username +} + +// validateUsernameLength checks if username length is within Windows limits +func validateUsernameLength(username string) error { + if len(username) > 20 { + return fmt.Errorf("username too long (max 20 characters for Windows)") + } + return nil +} + +// validateUsernameCharacters checks for invalid characters in Windows usernames +func validateUsernameCharacters(username string) error { + invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'} + for _, char := range username { + for _, invalid := range invalidChars { + if char == invalid { + return fmt.Errorf("username contains invalid characters") + } + } + if char < 32 || char == 127 { + return fmt.Errorf("username contains control characters") + } + } + return nil +} + +// validateUsernameFormat checks for invalid username formats and patterns +func validateUsernameFormat(username string) error { + if username == "." || username == ".." { + return fmt.Errorf("username cannot be '.' or '..'") + } + + if strings.HasSuffix(username, ".") { + return fmt.Errorf("username cannot end with a period") + } + + return nil +} + +// createExecutorCommand creates a command using Windows executor for privilege dropping. +// Returns the command and a cleanup function that must be called after starting the process. +func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) + + username, _ := s.parseUsername(localUser.Username) + if err := validateUsername(username); err != nil { + return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) + } + + return s.createUserSwitchCommand(localUser, session, hasPty) +} + +// createUserSwitchCommand creates a command with Windows user switching. +// Returns the command and a cleanup function that must be called after starting the process. +func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { + username, domain := s.parseUsername(localUser.Username) + + shell := getUserShell(localUser.Uid) + + rawCmd := session.RawCommand() + var command string + if rawCmd != "" { + command = rawCmd + } + + config := WindowsExecutorConfig{ + Username: username, + Domain: domain, + WorkingDir: localUser.HomeDir, + Shell: shell, + Command: command, + Interactive: interactive || (rawCmd == ""), + } + + dropper := NewPrivilegeDropper() + cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) + if err != nil { + return nil, nil, err + } + + cleanup := func() { + if token != 0 { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("close primary token: %v", err) + } + } + } + + return cmd, cleanup, nil +} + +// parseUsername extracts username and domain from a Windows username +func (s *Server) parseUsername(fullUsername string) (username, domain string) { + // Handle DOMAIN\username format + if idx := strings.LastIndex(fullUsername, `\`); idx != -1 { + domain = fullUsername[:idx] + username = fullUsername[idx+1:] + return username, domain + } + + // Handle username@domain format + if username, domain, ok := strings.Cut(fullUsername, "@"); ok { + return username, domain + } + + // Local user (no domain) + return fullUsername, "." +} + +// hasPrivilege checks if the current process has a specific privilege +func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) { + var luid windows.LUID + if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil { + return false, fmt.Errorf("lookup privilege value: %w", err) + } + + var returnLength uint32 + err := windows.GetTokenInformation( + windows.Token(token), + windows.TokenPrivileges, + nil, // null buffer to get size + 0, + &returnLength, + ) + + if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return false, fmt.Errorf("get token information size: %w", err) + } + + buffer := make([]byte, returnLength) + err = windows.GetTokenInformation( + windows.Token(token), + windows.TokenPrivileges, + &buffer[0], + returnLength, + &returnLength, + ) + if err != nil { + return false, fmt.Errorf("get token information: %w", err) + } + + privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0])) + + // Check if the privilege is present and enabled + for i := uint32(0); i < privileges.PrivilegeCount; i++ { + privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer( + uintptr(unsafe.Pointer(&privileges.Privileges[0])) + + uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}), + )) + if privilege.Luid == luid { + return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil + } + } + + return false, nil +} + +// enablePrivilege enables a specific privilege for the current process token +// This is required because privileges like SeAssignPrimaryTokenPrivilege are present +// but disabled by default, even for the SYSTEM account +func enablePrivilege(token windows.Handle, privilegeName string) error { + var luid windows.LUID + if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil { + return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err) + } + + privileges := windows.Tokenprivileges{ + PrivilegeCount: 1, + Privileges: [1]windows.LUIDAndAttributes{ + { + Luid: luid, + Attributes: windows.SE_PRIVILEGE_ENABLED, + }, + }, + } + + err := windows.AdjustTokenPrivileges( + windows.Token(token), + false, + &privileges, + 0, + nil, + nil, + ) + if err != nil { + return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err) + } + + hasPriv, err := hasPrivilege(token, privilegeName) + if err != nil { + return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err) + } + if !hasPriv { + return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName) + } + + log.Debugf("Successfully enabled privilege %s for current process", privilegeName) + return nil +} + +// enableUserSwitching enables required privileges for Windows user switching +func enableUserSwitching() error { + process := windows.CurrentProcess() + + var token windows.Token + err := windows.OpenProcessToken( + process, + windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, + &token, + ) + if err != nil { + return fmt.Errorf("open process token: %w", err) + } + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("Failed to close process token: %v", err) + } + }() + + if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil { + return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err) + } + log.Infof("Windows user switching privileges enabled successfully") + return nil +} diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go new file mode 100644 index 000000000..0f3659ffe --- /dev/null +++ b/client/ssh/server/winpty/conpty.go @@ -0,0 +1,487 @@ +//go:build windows + +package winpty + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +var ( + ErrEmptyEnvironment = errors.New("empty environment") +) + +const ( + extendedStartupInfoPresent = 0x00080000 + createUnicodeEnvironment = 0x00000400 + procThreadAttributePseudoConsole = 0x00020016 + + PowerShellCommandFlag = "-Command" + + errCloseInputRead = "close input read handle: %v" + errCloseConPtyCleanup = "close ConPty handle during cleanup" +) + +// PtyConfig holds configuration for Pty execution. +type PtyConfig struct { + Shell string + Command string + Width int + Height int + WorkingDir string +} + +// UserConfig holds user execution configuration. +type UserConfig struct { + Token windows.Handle + Environment []string +} + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") + procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList") + procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute") + procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList") +) + +// ExecutePtyWithUserToken executes a command with ConPty using user token. +func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { + args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) + commandLine := buildCommandLine(args) + + config := ExecutionConfig{ + Pty: ptyConfig, + User: userConfig, + Session: session, + Context: ctx, + } + + return executeConPtyWithConfig(commandLine, config) +} + +// ExecutionConfig holds all execution configuration. +type ExecutionConfig struct { + Pty PtyConfig + User UserConfig + Session ssh.Session + Context context.Context +} + +// executeConPtyWithConfig creates ConPty and executes process with configuration. +func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error { + ctx := config.Context + session := config.Session + width := config.Pty.Width + height := config.Pty.Height + userToken := config.User.Token + userEnv := config.User.Environment + workingDir := config.Pty.WorkingDir + + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + if err != nil { + return fmt.Errorf("create ConPty pipes: %w", err) + } + + hPty, err := createConPty(width, height, inputRead, outputWrite) + if err != nil { + return fmt.Errorf("create ConPty: %w", err) + } + + primaryToken, err := duplicateToPrimaryToken(userToken) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("duplicate to primary token: %w", err) + } + defer func() { + if err := windows.CloseHandle(primaryToken); err != nil { + log.Debugf("close primary token: %v", err) + } + }() + + siEx, err := setupConPtyStartupInfo(hPty) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("setup startup info: %w", err) + } + defer func() { + _, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList))) + }() + + pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("create process as user with ConPty: %w", err) + } + defer closeProcessInfo(pi) + + if err := windows.CloseHandle(inputRead); err != nil { + log.Debugf(errCloseInputRead, err) + } + if err := windows.CloseHandle(outputWrite); err != nil { + log.Debugf("close output write handle: %v", err) + } + + return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process) +} + +// createConPtyPipes creates input/output pipes for ConPty. +func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) { + if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil { + return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err) + } + + if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil { + if closeErr := windows.CloseHandle(inputRead); closeErr != nil { + log.Debugf(errCloseInputRead, closeErr) + } + if closeErr := windows.CloseHandle(inputWrite); closeErr != nil { + log.Debugf("close input write handle: %v", closeErr) + } + return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err) + } + + return inputRead, inputWrite, outputRead, outputWrite, nil +} + +// createConPty creates a Windows ConPty with the specified size and pipe handles. +func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) { + size := windows.Coord{X: int16(width), Y: int16(height)} + + var hPty windows.Handle + if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil { + return 0, fmt.Errorf("CreatePseudoConsole: %w", err) + } + + return hPty, nil +} + +// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes. +func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) { + var siEx windows.StartupInfoEx + siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx)) + + var attrListSize uintptr + ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize))) + if ret == 0 && attrListSize == 0 { + return nil, fmt.Errorf("get attribute list size") + } + + attrListBytes := make([]byte, attrListSize) + siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0])) + + ret, _, err := procInitializeProcThreadAttributeList.Call( + uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)), + 1, + 0, + uintptr(unsafe.Pointer(&attrListSize)), + ) + if ret == 0 { + return nil, fmt.Errorf("initialize attribute list: %w", err) + } + + ret, _, err = procUpdateProcThreadAttribute.Call( + uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)), + 0, + procThreadAttributePseudoConsole, + uintptr(hPty), + unsafe.Sizeof(hPty), + 0, + 0, + ) + if ret == 0 { + return nil, fmt.Errorf("update thread attribute: %w", err) + } + + return &siEx, nil +} + +// createConPtyProcess creates the actual process with ConPty. +func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) { + var pi windows.ProcessInformation + creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment) + + commandLinePtr, err := windows.UTF16PtrFromString(commandLine) + if err != nil { + return nil, fmt.Errorf("convert command line to UTF16: %w", err) + } + + envPtr, err := convertEnvironmentToUTF16(userEnv) + if err != nil { + return nil, err + } + + var workingDirPtr *uint16 + if workingDir != "" { + workingDirPtr, err = windows.UTF16PtrFromString(workingDir) + if err != nil { + return nil, fmt.Errorf("convert working directory to UTF16: %w", err) + } + } + + siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES + siEx.StartupInfo.StdInput = windows.Handle(0) + siEx.StartupInfo.StdOutput = windows.Handle(0) + siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput + + if userToken != windows.InvalidHandle { + err = windows.CreateProcessAsUser( + windows.Token(userToken), + nil, + commandLinePtr, + nil, + nil, + true, + creationFlags, + envPtr, + workingDirPtr, + &siEx.StartupInfo, + &pi, + ) + } else { + err = windows.CreateProcess( + nil, + commandLinePtr, + nil, + nil, + true, + creationFlags, + envPtr, + workingDirPtr, + &siEx.StartupInfo, + &pi, + ) + } + + if err != nil { + return nil, fmt.Errorf("create process: %w", err) + } + + return &pi, nil +} + +// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format. +func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) { + if len(userEnv) == 0 { + // Return nil pointer for empty environment - Windows API will inherit parent environment + return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment + } + + var envUTF16 []uint16 + for _, envVar := range userEnv { + if envVar != "" { + utf16Str, err := windows.UTF16FromString(envVar) + if err != nil { + log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err) + continue + } + envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...) + envUTF16 = append(envUTF16, 0) + } + } + envUTF16 = append(envUTF16, 0) + + if len(envUTF16) > 0 { + return &envUTF16[0], nil + } + // Return nil pointer when no valid environment variables found + return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment +} + +// duplicateToPrimaryToken converts an impersonation token to a primary token. +func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) { + var primaryToken windows.Handle + if err := windows.DuplicateTokenEx( + windows.Token(token), + windows.TOKEN_ALL_ACCESS, + nil, + windows.SecurityImpersonation, + windows.TokenPrimary, + (*windows.Token)(&primaryToken), + ); err != nil { + return 0, fmt.Errorf("duplicate token: %w", err) + } + return primaryToken, nil +} + +// SessionExiter provides the Exit method for reporting process exit status. +type SessionExiter interface { + Exit(code int) error +} + +// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers. +func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error { + if err := ctx.Err(); err != nil { + return err + } + + var wg sync.WaitGroup + startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer) + + processErr := waitForProcess(ctx, process) + if processErr != nil { + return processErr + } + + var exitCode uint32 + if err := windows.GetExitCodeProcess(process, &exitCode); err != nil { + log.Debugf("get exit code: %v", err) + } else { + if err := session.Exit(int(exitCode)); err != nil { + log.Debugf("report exit code: %v", err) + } + } + + // Clean up in the original order after process completes + if err := reader.Close(); err != nil { + log.Debugf("close reader: %v", err) + } + + ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)) + if ret == 0 { + log.Debugf("close ConPty handle: %v", err) + } + + wg.Wait() + + if err := windows.CloseHandle(outputRead); err != nil { + log.Debugf("close output read handle: %v", err) + } + + return nil +} + +// startIOBridging starts the I/O bridging goroutines. +func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) { + wg.Add(2) + + // Input: reader (SSH session) -> inputWrite (ConPty) + go func() { + defer wg.Done() + defer func() { + if err := windows.CloseHandle(inputWrite); err != nil { + log.Debugf("close input write handle in goroutine: %v", err) + } + }() + + if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil { + log.Debugf("input copy ended with error: %v", err) + } + }() + + // Output: outputRead (ConPty) -> writer (SSH session) + go func() { + defer wg.Done() + if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil { + log.Debugf("output copy ended with error: %v", err) + } + }() +} + +// waitForProcess waits for process completion with context cancellation. +func waitForProcess(ctx context.Context, process windows.Handle) error { + if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil { + return fmt.Errorf("wait for process %d: %w", process, err) + } + return nil +} + +// buildShellArgs builds shell arguments for ConPty execution. +func buildShellArgs(shell, command string) []string { + if command != "" { + return []string{shell, PowerShellCommandFlag, command} + } + return []string{shell} +} + +// buildCommandLine builds a Windows command line from arguments using proper escaping. +func buildCommandLine(args []string) string { + if len(args) == 0 { + return "" + } + + var result strings.Builder + for i, arg := range args { + if i > 0 { + result.WriteString(" ") + } + result.WriteString(syscall.EscapeArg(arg)) + } + return result.String() +} + +// closeHandles closes multiple Windows handles. +func closeHandles(handles ...windows.Handle) { + for _, handle := range handles { + if handle != windows.InvalidHandle { + if err := windows.CloseHandle(handle); err != nil { + log.Debugf("close handle: %v", err) + } + } + } +} + +// closeProcessInfo closes process and thread handles. +func closeProcessInfo(pi *windows.ProcessInformation) { + if pi != nil { + if err := windows.CloseHandle(pi.Process); err != nil { + log.Debugf("close process handle: %v", err) + } + if err := windows.CloseHandle(pi.Thread); err != nil { + log.Debugf("close thread handle: %v", err) + } + } +} + +// windowsHandleReader wraps a Windows handle for reading. +type windowsHandleReader struct { + handle windows.Handle +} + +func (r *windowsHandleReader) Read(p []byte) (n int, err error) { + var bytesRead uint32 + if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil { + return 0, err + } + return int(bytesRead), nil +} + +func (r *windowsHandleReader) Close() error { + return windows.CloseHandle(r.handle) +} + +// windowsHandleWriter wraps a Windows handle for writing. +type windowsHandleWriter struct { + handle windows.Handle +} + +func (w *windowsHandleWriter) Write(p []byte) (n int, err error) { + var bytesWritten uint32 + if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil { + return 0, err + } + return int(bytesWritten), nil +} + +func (w *windowsHandleWriter) Close() error { + return windows.CloseHandle(w.handle) +} diff --git a/client/ssh/server/winpty/conpty_test.go b/client/ssh/server/winpty/conpty_test.go new file mode 100644 index 000000000..4f04e1fad --- /dev/null +++ b/client/ssh/server/winpty/conpty_test.go @@ -0,0 +1,290 @@ +//go:build windows + +package winpty + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" +) + +func TestBuildShellArgs(t *testing.T) { + tests := []struct { + name string + shell string + command string + expected []string + }{ + { + name: "Shell with command", + shell: "powershell.exe", + command: "Get-Process", + expected: []string{"powershell.exe", "-Command", "Get-Process"}, + }, + { + name: "CMD with command", + shell: "cmd.exe", + command: "dir", + expected: []string{"cmd.exe", "-Command", "dir"}, + }, + { + name: "Shell interactive", + shell: "powershell.exe", + command: "", + expected: []string{"powershell.exe"}, + }, + { + name: "CMD interactive", + shell: "cmd.exe", + command: "", + expected: []string{"cmd.exe"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildShellArgs(tt.shell, tt.command) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBuildCommandLine(t *testing.T) { + tests := []struct { + name string + args []string + expected string + }{ + { + name: "Simple args", + args: []string{"cmd.exe", "/c", "echo"}, + expected: "cmd.exe /c echo", + }, + { + name: "Args with spaces", + args: []string{"Program Files\\app.exe", "arg with spaces"}, + expected: `"Program Files\app.exe" "arg with spaces"`, + }, + { + name: "Args with quotes", + args: []string{"cmd.exe", "/c", `echo "hello world"`}, + expected: `cmd.exe /c "echo \"hello world\""`, + }, + { + name: "PowerShell calling PowerShell", + args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`}, + expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`, + }, + { + name: "Complex nested quotes", + args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`}, + expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`, + }, + { + name: "Path with spaces and args", + args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`}, + expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`, + }, + { + name: "Empty argument", + args: []string{"cmd.exe", "/c", "echo", ""}, + expected: `cmd.exe /c echo ""`, + }, + { + name: "Argument with backslashes", + args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"}, + expected: `robocopy C:\Source\ C:\Dest\ /E`, + }, + { + name: "Empty args", + args: []string{}, + expected: "", + }, + { + name: "Single arg with space", + args: []string{"path with spaces"}, + expected: `"path with spaces"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildCommandLine(tt.args) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCreateConPtyPipes(t *testing.T) { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + require.NoError(t, err, "Should create ConPty pipes successfully") + + // Verify all handles are valid + assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid") + + // Clean up handles + closeHandles(inputRead, inputWrite, outputRead, outputWrite) +} + +func TestCreateConPty(t *testing.T) { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + require.NoError(t, err, "Should create ConPty pipes successfully") + defer closeHandles(inputRead, inputWrite, outputRead, outputWrite) + + hPty, err := createConPty(80, 24, inputRead, outputWrite) + require.NoError(t, err, "Should create ConPty successfully") + assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid") + + // Clean up ConPty + ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty)) + assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully") +} + +func TestConvertEnvironmentToUTF16(t *testing.T) { + tests := []struct { + name string + userEnv []string + hasError bool + }{ + { + name: "Valid environment variables", + userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"}, + hasError: false, + }, + { + name: "Empty environment", + userEnv: []string{}, + hasError: false, + }, + { + name: "Environment with empty strings", + userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"}, + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := convertEnvironmentToUTF16(tt.userEnv) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if len(tt.userEnv) == 0 { + assert.Nil(t, result, "Empty environment should return nil") + } else { + assert.NotNil(t, result, "Non-empty environment should return valid pointer") + } + } + }) + } +} + +func TestDuplicateToPrimaryToken(t *testing.T) { + if testing.Short() { + t.Skip("Skipping token tests in short mode") + } + + // Get current process token for testing + var token windows.Token + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token) + require.NoError(t, err, "Should open current process token") + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + t.Logf("Failed to close token: %v", err) + } + }() + + primaryToken, err := duplicateToPrimaryToken(windows.Handle(token)) + require.NoError(t, err, "Should duplicate token to primary") + assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid") + + // Clean up + err = windows.CloseHandle(primaryToken) + assert.NoError(t, err, "Should close primary token") +} + +func TestWindowsHandleReader(t *testing.T) { + // Create a pipe for testing + var readHandle, writeHandle windows.Handle + err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0) + require.NoError(t, err, "Should create pipe for testing") + defer closeHandles(readHandle, writeHandle) + + // Write test data + testData := []byte("Hello, Windows Handle Reader!") + var bytesWritten uint32 + err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil) + require.NoError(t, err, "Should write test data") + require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data") + + // Close write handle to signal EOF + if err := windows.CloseHandle(writeHandle); err != nil { + t.Fatalf("Should close write handle: %v", err) + } + writeHandle = windows.InvalidHandle + + // Test reading + reader := &windowsHandleReader{handle: readHandle} + buffer := make([]byte, len(testData)) + n, err := reader.Read(buffer) + require.NoError(t, err, "Should read from handle") + assert.Equal(t, len(testData), n, "Should read expected number of bytes") + assert.Equal(t, testData, buffer, "Should read expected data") +} + +func TestWindowsHandleWriter(t *testing.T) { + // Create a pipe for testing + var readHandle, writeHandle windows.Handle + err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0) + require.NoError(t, err, "Should create pipe for testing") + defer closeHandles(readHandle, writeHandle) + + // Test writing + testData := []byte("Hello, Windows Handle Writer!") + writer := &windowsHandleWriter{handle: writeHandle} + n, err := writer.Write(testData) + require.NoError(t, err, "Should write to handle") + assert.Equal(t, len(testData), n, "Should write expected number of bytes") + + // Close write handle + if err := windows.CloseHandle(writeHandle); err != nil { + t.Fatalf("Should close write handle: %v", err) + } + + // Verify data was written by reading it back + buffer := make([]byte, len(testData)) + var bytesRead uint32 + err = windows.ReadFile(readHandle, buffer, &bytesRead, nil) + require.NoError(t, err, "Should read back written data") + assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes") + assert.Equal(t, testData, buffer, "Should read back expected data") +} + +// BenchmarkConPtyCreation benchmarks ConPty creation performance +func BenchmarkConPtyCreation(b *testing.B) { + for i := 0; i < b.N; i++ { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + if err != nil { + b.Fatal(err) + } + + hPty, err := createConPty(80, 24, inputRead, outputWrite) + if err != nil { + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + b.Fatal(err) + } + + // Clean up + if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 { + log.Debugf("ClosePseudoConsole failed: %v", err) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + } +} diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go deleted file mode 100644 index 76f43fd4e..000000000 --- a/client/ssh/server_mock.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build !js - -package ssh - -import "context" - -// MockServer mocks ssh.Server -type MockServer struct { - Ctx context.Context - StopFunc func() error - StartFunc func() error - AddAuthorizedKeyFunc func(peer, newKey string) error - RemoveAuthorizedKeyFunc func(peer string) -} - -// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys -func (srv *MockServer) RemoveAuthorizedKey(peer string) { - if srv.RemoveAuthorizedKeyFunc == nil { - return - } - srv.RemoveAuthorizedKeyFunc(peer) -} - -// AddAuthorizedKey add a given peer key to server authorized keys -func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error { - if srv.AddAuthorizedKeyFunc == nil { - return nil - } - return srv.AddAuthorizedKeyFunc(peer, newKey) -} - -// Stop stops SSH server. -func (srv *MockServer) Stop() error { - if srv.StopFunc == nil { - return nil - } - return srv.StopFunc() -} - -// Start starts SSH server. Blocking -func (srv *MockServer) Start() error { - if srv.StartFunc == nil { - return nil - } - return srv.StartFunc() -} diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go deleted file mode 100644 index 1f310c2bb..000000000 --- a/client/ssh/server_test.go +++ /dev/null @@ -1,123 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/ssh" - "strings" - "testing" -) - -func TestServer_AddAuthorizedKey(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - // add multiple keys - keys := map[string][]byte{} - for i := 0; i < 10; i++ { - peer := fmt.Sprintf("%s-%d", "remotePeer", i) - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey(peer, string(remotePubKey)) - if err != nil { - t.Error(err) - } - keys[peer] = remotePubKey - } - - // make sure that all keys have been added - for peer, remotePubKey := range keys { - k, ok := server.authorizedKeys[peer] - assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys") - - assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k)))) - } - -} - -func TestServer_RemoveAuthorizedKey(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey("remotePeer", string(remotePubKey)) - if err != nil { - t.Error(err) - } - - server.RemoveAuthorizedKey("remotePeer") - - _, ok := server.authorizedKeys["remotePeer"] - assert.False(t, ok, "expecting remotePeer's SSH key to be removed") -} - -func TestServer_PubKeyHandler(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - var keys []ssh.PublicKey - for i := 0; i < 10; i++ { - peer := fmt.Sprintf("%s-%d", "remotePeer", i) - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey(peer, string(remotePubKey)) - if err != nil { - t.Error(err) - } - keys = append(keys, remoteParsedPubKey) - } - - for _, key := range keys { - accepted := server.publicKeyHandler(nil, key) - - assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key))) - } - -} diff --git a/client/ssh/util.go b/client/ssh/ssh.go similarity index 86% rename from client/ssh/util.go rename to client/ssh/ssh.go index a54a609bc..c0024c599 100644 --- a/client/ssh/util.go +++ b/client/ssh/ssh.go @@ -32,9 +32,8 @@ const RSA KeyType = "rsa" // RSAKeySize is a size of newly generated RSA key const RSAKeySize = 2048 -// GeneratePrivateKey creates RSA Private Key of specified byte size +// GeneratePrivateKey creates a private key of the specified type. func GeneratePrivateKey(keyType KeyType) ([]byte, error) { - var key crypto.Signer var err error switch keyType { @@ -59,7 +58,7 @@ func GeneratePrivateKey(keyType KeyType) ([]byte, error) { return pemBytes, nil } -// GeneratePublicKey returns the public part of the private key +// GeneratePublicKey returns the public part of the private key. func GeneratePublicKey(key []byte) ([]byte, error) { signer, err := gossh.ParsePrivateKey(key) if err != nil { @@ -70,20 +69,17 @@ func GeneratePublicKey(key []byte) ([]byte, error) { return []byte(strKey), nil } -// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format +// EncodePrivateKeyToPEM encodes a private key to PEM format. func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) { mk, err := x509.MarshalPKCS8PrivateKey(privateKey) if err != nil { return nil, err } - // pem.Block privBlock := pem.Block{ Type: "PRIVATE KEY", Bytes: mk, } - - // Private key in PEM format privatePEM := pem.EncodeToMemory(&privBlock) return privatePEM, nil } diff --git a/client/ssh/testutil/user_helpers.go b/client/ssh/testutil/user_helpers.go new file mode 100644 index 000000000..8960d8dd0 --- /dev/null +++ b/client/ssh/testutil/user_helpers.go @@ -0,0 +1,173 @@ +package testutil + +import ( + "fmt" + "log" + "os" + "os/exec" + "os/user" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +var testCreatedUsers = make(map[string]bool) +var testUsersToCleanup []string + +// GetTestUsername returns an appropriate username for testing +func GetTestUsername(t *testing.T) string { + if runtime.GOOS == "windows" { + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + if IsSystemAccount(currentUser.Username) { + if IsCI() { + if testUser := GetOrCreateTestUser(t); testUser != "" { + return testUser + } + } else { + if _, err := user.Lookup("Administrator"); err == nil { + return "Administrator" + } + if testUser := GetOrCreateTestUser(t); testUser != "" { + return testUser + } + } + } + return currentUser.Username + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + return currentUser.Username +} + +// IsCI checks if we're running in a CI environment +func IsCI() bool { + if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" { + return true + } + + hostname, err := os.Hostname() + if err == nil && strings.HasPrefix(hostname, "runner") { + return true + } + + return false +} + +// IsSystemAccount checks if the user is a system account that can't authenticate +func IsSystemAccount(username string) bool { + systemAccounts := []string{ + "system", + "NT AUTHORITY\\SYSTEM", + "NT AUTHORITY\\LOCAL SERVICE", + "NT AUTHORITY\\NETWORK SERVICE", + } + + for _, sysAccount := range systemAccounts { + if strings.EqualFold(username, sysAccount) { + return true + } + } + + return strings.HasSuffix(username, "$") +} + +// RegisterTestUserCleanup registers a test user for cleanup +func RegisterTestUserCleanup(username string) { + if !testCreatedUsers[username] { + testCreatedUsers[username] = true + testUsersToCleanup = append(testUsersToCleanup, username) + } +} + +// CleanupTestUsers removes all created test users +func CleanupTestUsers() { + for _, username := range testUsersToCleanup { + RemoveWindowsTestUser(username) + } + testUsersToCleanup = nil + testCreatedUsers = make(map[string]bool) +} + +// GetOrCreateTestUser creates a test user on Windows if needed +func GetOrCreateTestUser(t *testing.T) string { + testUsername := "netbird-test-user" + + if _, err := user.Lookup(testUsername); err == nil { + return testUsername + } + + if CreateWindowsTestUser(t, testUsername) { + RegisterTestUserCleanup(testUsername) + return testUsername + } + + return "" +} + +// RemoveWindowsTestUser removes a local user on Windows using PowerShell +func RemoveWindowsTestUser(username string) { + if runtime.GOOS != "windows" { + return + } + + psCmd := fmt.Sprintf(` + try { + Remove-LocalUser -Name "%s" -ErrorAction Stop + Write-Output "User removed successfully" + } catch { + if ($_.Exception.Message -like "*cannot be found*") { + Write-Output "User not found (already removed)" + } else { + Write-Error $_.Exception.Message + } + } + `, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output)) + } else { + log.Printf("Test user %s cleanup result: %s", username, string(output)) + } +} + +// CreateWindowsTestUser creates a local user on Windows using PowerShell +func CreateWindowsTestUser(t *testing.T, username string) bool { + if runtime.GOOS != "windows" { + return false + } + + psCmd := fmt.Sprintf(` + try { + $password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force + New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires + Add-LocalGroupMember -Group "Users" -Member "%s" + Write-Output "User created successfully" + } catch { + if ($_.Exception.Message -like "*already exists*") { + Write-Output "User already exists" + } else { + Write-Error $_.Exception.Message + exit 1 + } + } + `, username, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("Failed to create test user: %v, output: %s", err, string(output)) + return false + } + + t.Logf("Test user creation result: %s", string(output)) + return true +} diff --git a/client/ssh/testutil/user_helpers_test.go b/client/ssh/testutil/user_helpers_test.go new file mode 100644 index 000000000..db2f5f06d --- /dev/null +++ b/client/ssh/testutil/user_helpers_test.go @@ -0,0 +1,115 @@ +package testutil + +import ( + "os/user" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUserCurrentBehavior validates user.Current() behavior on Windows. +// When running as SYSTEM on a domain-joined machine, user.Current() returns: +// - Username: Computer account name (e.g., "DOMAIN\MACHINE$") +// - SID: SYSTEM SID (S-1-5-18) +func TestUserCurrentBehavior(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific test") + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + t.Logf("Current user - Username: %s, SID: %s", currentUser.Username, currentUser.Uid) + + // When running as SYSTEM, validate expected behavior + if currentUser.Uid == "S-1-5-18" { + t.Run("SYSTEM_account_behavior", func(t *testing.T) { + // SID must be S-1-5-18 for SYSTEM + require.Equal(t, "S-1-5-18", currentUser.Uid, + "SYSTEM account must have SID S-1-5-18") + + // Username can be either "NT AUTHORITY\SYSTEM" (standalone) + // or "DOMAIN\MACHINE$" (domain-joined) + username := currentUser.Username + isNTAuthority := strings.Contains(strings.ToUpper(username), "NT AUTHORITY") + isComputerAccount := strings.HasSuffix(username, "$") + + assert.True(t, isNTAuthority || isComputerAccount, + "Username should be either 'NT AUTHORITY\\SYSTEM' or computer account (ending with $), got: %s", + username) + + if isComputerAccount { + t.Logf("SYSTEM as computer account: %s", username) + } else if isNTAuthority { + t.Logf("SYSTEM as NT AUTHORITY\\SYSTEM") + } + }) + } + + // Validate that IsSystemAccount correctly identifies system accounts + t.Run("IsSystemAccount_validation", func(t *testing.T) { + // Test with current user if it's a system account + if currentUser.Uid == "S-1-5-18" || // SYSTEM + currentUser.Uid == "S-1-5-19" || // LOCAL SERVICE + currentUser.Uid == "S-1-5-20" { // NETWORK SERVICE + + result := IsSystemAccount(currentUser.Username) + assert.True(t, result, + "IsSystemAccount should recognize system account: %s (SID: %s)", + currentUser.Username, currentUser.Uid) + } + + // Test explicit cases + testCases := []struct { + username string + expected bool + reason string + }{ + {"NT AUTHORITY\\SYSTEM", true, "NT AUTHORITY\\SYSTEM"}, + {"system", true, "system"}, + {"SYSTEM", true, "SYSTEM (case insensitive)"}, + {"NT AUTHORITY\\LOCAL SERVICE", true, "LOCAL SERVICE"}, + {"NT AUTHORITY\\NETWORK SERVICE", true, "NETWORK SERVICE"}, + {"DOMAIN\\MACHINE$", true, "computer account (ends with $)"}, + {"WORKGROUP\\WIN2K19-C2$", true, "computer account (ends with $)"}, + {"Administrator", false, "Administrator is not a system account"}, + {"alice", false, "regular user"}, + {"DOMAIN\\alice", false, "domain user"}, + } + + for _, tc := range testCases { + t.Run(tc.username, func(t *testing.T) { + result := IsSystemAccount(tc.username) + assert.Equal(t, tc.expected, result, + "IsSystemAccount(%q) should be %v because: %s", + tc.username, tc.expected, tc.reason) + }) + } + }) +} + +// TestComputerAccountDetection validates computer account detection. +func TestComputerAccountDetection(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific test") + } + + computerAccounts := []string{ + "MACHINE$", + "WIN2K19-C2$", + "DOMAIN\\MACHINE$", + "WORKGROUP\\SERVER$", + "server.domain.com$", + } + + for _, account := range computerAccounts { + t.Run(account, func(t *testing.T) { + result := IsSystemAccount(account) + assert.True(t, result, + "Computer account %q should be recognized as system account", account) + }) + } +} diff --git a/client/ssh/window_freebsd.go b/client/ssh/window_freebsd.go deleted file mode 100644 index ef4848341..000000000 --- a/client/ssh/window_freebsd.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build freebsd - -package ssh - -import ( - "os" -) - -func setWinSize(file *os.File, width, height int) { -} diff --git a/client/ssh/window_unix.go b/client/ssh/window_unix.go deleted file mode 100644 index 2891eb70e..000000000 --- a/client/ssh/window_unix.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux || darwin - -package ssh - -import ( - "os" - "syscall" - "unsafe" -) - -func setWinSize(file *os.File, width, height int) { - syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint - uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0}))) -} diff --git a/client/ssh/window_windows.go b/client/ssh/window_windows.go deleted file mode 100644 index 5abd41f27..000000000 --- a/client/ssh/window_windows.go +++ /dev/null @@ -1,9 +0,0 @@ -package ssh - -import ( - "os" -) - -func setWinSize(file *os.File, width, height int) { - -} diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..d975f0e29 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -80,6 +81,18 @@ type NsServerGroupStateOutput struct { Error string `json:"error" yaml:"error"` } +type SSHSessionOutput struct { + Username string `json:"username" yaml:"username"` + RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` + Command string `json:"command" yaml:"command"` + JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` +} + +type SSHServerStateOutput struct { + Enabled bool `json:"enabled" yaml:"enabled"` + Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"` +} + type OutputOverview struct { Peers PeersStateOutput `json:"peers" yaml:"peers"` CliVersion string `json:"cliVersion" yaml:"cliVersion"` @@ -99,6 +112,7 @@ type OutputOverview struct { Events []SystemEventOutput `json:"events" yaml:"events"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` ProfileName string `json:"profileName" yaml:"profileName"` + SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"` } func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { @@ -120,6 +134,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status relayOverview := mapRelays(pbFullStatus.GetRelays()) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) + sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState()) overview := OutputOverview{ Peers: peersOverview, @@ -140,6 +155,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status Events: mapEvents(pbFullStatus.GetEvents()), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), ProfileName: profName, + SSHServerState: sshServerOverview, } if anon { @@ -189,6 +205,30 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput { return mappedNSGroups } +func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput { + if sshServerState == nil { + return SSHServerStateOutput{ + Enabled: false, + Sessions: []SSHSessionOutput{}, + } + } + + sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions())) + for _, session := range sshServerState.GetSessions() { + sessions = append(sessions, SSHSessionOutput{ + Username: session.GetUsername(), + RemoteAddress: session.GetRemoteAddress(), + Command: session.GetCommand(), + JWTUsername: session.GetJwtUsername(), + }) + } + + return SSHServerStateOutput{ + Enabled: sshServerState.GetEnabled(), + Sessions: sessions, + } +} + func mapPeers( peers []*proto.PeerState, statusFilter string, @@ -205,15 +245,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { @@ -296,7 +339,7 @@ func ParseToYAML(overview OutputOverview) (string, error) { return string(yamlBytes), nil } -func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string { +func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { var managementConnString string if overview.ManagementState.Connected { managementConnString = "Connected" @@ -337,10 +380,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." + } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { @@ -395,6 +444,41 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, lazyConnectionEnabledStatus = "true" } + sshServerStatus := "Disabled" + if overview.SSHServerState.Enabled { + sessionCount := len(overview.SSHServerState.Sessions) + if sessionCount > 0 { + sessionWord := "session" + if sessionCount > 1 { + sessionWord = "sessions" + } + sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord) + } else { + sshServerStatus = "Enabled" + } + + if showSSHSessions && sessionCount > 0 { + for _, session := range overview.SSHServerState.Sessions { + var sessionDisplay string + if session.JWTUsername != "" { + sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s", + session.JWTUsername, + session.RemoteAddress, + session.Username, + session.Command, + ) + } else { + sessionDisplay = fmt.Sprintf("[%s@%s] %s", + session.Username, + session.RemoteAddress, + session.Command, + ) + } + sshServerStatus += "\n " + sessionDisplay + } + } + } + peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) goos := runtime.GOOS @@ -418,6 +502,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Lazy connection: %s\n"+ + "SSH Server: %s\n"+ "Networks: %s\n"+ "Forwarding rules: %d\n"+ "Peers count: %s\n", @@ -434,6 +519,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, interfaceTypeString, rosenpassEnabledStatus, lazyConnectionEnabledStatus, + sshServerStatus, networks, overview.NumberOfForwardingRules, peersCountString, @@ -444,7 +530,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, func ParseToFullDetailSummary(overview OutputOverview) string { parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) parsedEventsString := parseEvents(overview.Events) - summary := ParseGeneralSummary(overview, true, true, true) + summary := ParseGeneralSummary(overview, true, true, true, true) return fmt.Sprintf( "Peers detail:"+ @@ -736,4 +822,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { event.Metadata[k] = a.AnonymizeString(v) } } + + for i, session := range overview.SSHServerState.Sessions { + if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil { + overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port) + } else { + overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress) + } + overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command) + } } diff --git a/client/status/status_test.go b/client/status/status_test.go index 660efd9ef..1dca1e5b1 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -231,6 +231,10 @@ var overview = OutputOverview{ Networks: []string{ "10.10.0.0/24", }, + SSHServerState: SSHServerStateOutput{ + Enabled: false, + Sessions: []SSHSessionOutput{}, + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -385,7 +389,11 @@ func TestParsingToJSON(t *testing.T) { ], "events": [], "lazyConnectionEnabled": false, - "profileName":"" + "profileName":"", + "sshServer":{ + "enabled":false, + "sessions":[] + } }` // @formatter:on @@ -488,6 +496,9 @@ dnsServers: events: [] lazyConnectionEnabled: false profileName: "" +sshServer: + enabled: false + sessions: [] ` assert.Equal(t, expectedYAML, yaml) @@ -554,6 +565,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Lazy connection: false +SSH Server: Disabled Networks: 10.10.0.0/24 Forwarding rules: 0 Peers count: 2/2 Connected @@ -563,7 +575,7 @@ Peers count: 2/2 Connected } func TestParsingToShortVersion(t *testing.T) { - shortVersion := ParseGeneralSummary(overview, false, false, false) + shortVersion := ParseGeneralSummary(overview, false, false, false, false) expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` Daemon version: 0.14.1 @@ -578,6 +590,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Lazy connection: false +SSH Server: Disabled Networks: 10.10.0.0/24 Forwarding rules: 0 Peers count: 2/2 Connected diff --git a/client/system/info.go b/client/system/info.go index a180be4c0..01176e765 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -72,6 +72,12 @@ type Info struct { BlockInbound bool LazyConnectionEnabled bool + + EnableSSHRoot bool + EnableSSHSFTP bool + EnableSSHLocalPortForwarding bool + EnableSSHRemotePortForwarding bool + DisableSSHAuth bool } func (i *Info) SetFlags( @@ -79,6 +85,8 @@ func (i *Info) SetFlags( serverSSHAllowed *bool, disableClientRoutes, disableServerRoutes, disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool, + enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, + disableSSHAuth *bool, ) { i.RosenpassEnabled = rosenpassEnabled i.RosenpassPermissive = rosenpassPermissive @@ -94,6 +102,22 @@ func (i *Info) SetFlags( i.BlockInbound = blockInbound i.LazyConnectionEnabled = lazyConnectionEnabled + + if enableSSHRoot != nil { + i.EnableSSHRoot = *enableSSHRoot + } + if enableSSHSFTP != nil { + i.EnableSSHSFTP = *enableSSHSFTP + } + if enableSSHLocalPortForwarding != nil { + i.EnableSSHLocalPortForwarding = *enableSSHLocalPortForwarding + } + if enableSSHRemotePortForwarding != nil { + i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding + } + if disableSSHAuth != nil { + i.DisableSSHAuth = *disableSSHAuth + } } // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context diff --git a/client/ui/assets/netbird-disconnected.ico b/client/ui/assets/netbird-disconnected.ico new file mode 100644 index 000000000..812e9d283 Binary files /dev/null and b/client/ui/assets/netbird-disconnected.ico differ diff --git a/client/ui/assets/netbird-disconnected.png b/client/ui/assets/netbird-disconnected.png new file mode 100644 index 000000000..79d4775ea Binary files /dev/null and b/client/ui/assets/netbird-disconnected.png differ diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7c2000a9d..87bac8c31 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,19 +31,19 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + protobuf "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" "github.com/netbirdio/netbird/client/ui/process" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -56,6 +56,7 @@ const ( const ( censoredPreSharedKey = "**********" + maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds ) func main() { @@ -86,21 +87,24 @@ func main() { // Create the service client (this also builds the settings or networks UI if requested). client := newServiceClient(&newServiceClientArgs{ - addr: flags.daemonAddr, - logFile: logFile, - app: a, - showSettings: flags.showSettings, - showNetworks: flags.showNetworks, - showLoginURL: flags.showLoginURL, - showDebug: flags.showDebug, - showProfiles: flags.showProfiles, + addr: flags.daemonAddr, + logFile: logFile, + app: a, + showSettings: flags.showSettings, + showNetworks: flags.showNetworks, + showLoginURL: flags.showLoginURL, + showDebug: flags.showDebug, + showProfiles: flags.showProfiles, + showQuickActions: flags.showQuickActions, + showUpdate: flags.showUpdate, + showUpdateVersion: flags.showUpdateVersion, }) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles { + if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate { a.Run() return } @@ -112,23 +116,31 @@ func main() { return } if running { - log.Warnf("another process is running with pid %d, exiting", pid) + log.Infof("another process is running with pid %d, sending signal to show window", pid) + if err := sendShowWindowSignal(pid); err != nil { + log.Errorf("send signal to running instance: %v", err) + } return } + client.setupSignalHandler(client.ctx) + client.setDefaultFonts() systray.Run(client.onTrayReady, client.onTrayExit) } type cliFlags struct { - daemonAddr string - showSettings bool - showNetworks bool - showProfiles bool - showDebug bool - showLoginURL bool - errorMsg string - saveLogsInFile bool + daemonAddr string + showSettings bool + showNetworks bool + showProfiles bool + showDebug bool + showLoginURL bool + showQuickActions bool + errorMsg string + saveLogsInFile bool + showUpdate bool + showUpdateVersion string } // parseFlags reads and returns all needed command-line flags. @@ -144,9 +156,12 @@ func parseFlags() *cliFlags { flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window") flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window") flag.BoolVar(&flags.showDebug, "debug", false, "run debug window") + flag.BoolVar(&flags.showQuickActions, "quick-actions", false, "run quick actions window") flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window") flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window") + flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window") + flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to") flag.Parse() return &flags } @@ -159,11 +174,9 @@ func initLogFile() (string, error) { // watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon. func watchSettingsChanges(a fyne.App, client *serviceClient) { - settingsChangeChan := make(chan fyne.Settings) - a.Settings().AddChangeListener(settingsChangeChan) - for range settingsChangeChan { + a.Settings().AddListener(func(settings fyne.Settings) { client.updateIcon() - } + }) } // showErrorMessage displays an error message in a simple window. @@ -203,10 +216,11 @@ var iconConnectedDot []byte var iconDisconnectedDot []byte type serviceClient struct { - ctx context.Context - cancel context.CancelFunc - addr string - conn proto.DaemonServiceClient + ctx context.Context + cancel context.CancelFunc + addr string + conn proto.DaemonServiceClient + connLock sync.Mutex eventHandler *eventHandler @@ -260,34 +274,50 @@ type serviceClient struct { iMTU *widget.Entry // switch elements for settings form - sRosenpassPermissive *widget.Check - sNetworkMonitor *widget.Check - sDisableDNS *widget.Check - sDisableClientRoutes *widget.Check - sDisableServerRoutes *widget.Check - sBlockLANAccess *widget.Check + sRosenpassPermissive *widget.Check + sNetworkMonitor *widget.Check + sDisableDNS *widget.Check + sDisableClientRoutes *widget.Check + sDisableServerRoutes *widget.Check + sBlockLANAccess *widget.Check + sEnableSSHRoot *widget.Check + sEnableSSHSFTP *widget.Check + sEnableSSHLocalPortForward *widget.Check + sEnableSSHRemotePortForward *widget.Check + sDisableSSHAuth *widget.Check + iSSHJWTCacheTTL *widget.Entry // observable settings over corresponding iMngURL and iPreSharedKey values. - managementURL string - preSharedKey string - RosenpassPermissive bool - interfaceName string - interfacePort int - mtu uint16 - networkMonitor bool - disableDNS bool - disableClientRoutes bool - disableServerRoutes bool - blockLANAccess bool + managementURL string + preSharedKey string + + RosenpassPermissive bool + interfaceName string + interfacePort int + mtu uint16 + networkMonitor bool + disableDNS bool + disableClientRoutes bool + disableServerRoutes bool + blockLANAccess bool + enableSSHRoot bool + enableSSHSFTP bool + enableSSHLocalPortForward bool + enableSSHRemotePortForward bool + disableSSHAuth bool + sshJWTCacheTTL int connected bool update *version.Update daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + settingsEnabled bool + profilesEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window + wQuickActions fyne.Window eventManager *event.Manager @@ -297,6 +327,10 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + wUpdateProgress fyne.Window + updateContextCancel context.CancelFunc + + connectCancel context.CancelFunc } type menuHandler struct { @@ -305,14 +339,17 @@ type menuHandler struct { } type newServiceClientArgs struct { - addr string - logFile string - app fyne.App - showSettings bool - showNetworks bool - showDebug bool - showLoginURL bool - showProfiles bool + addr string + logFile string + app fyne.App + showSettings bool + showNetworks bool + showDebug bool + showLoginURL bool + showProfiles bool + showQuickActions bool + showUpdate bool + showUpdateVersion string } // newServiceClient instance constructor @@ -330,7 +367,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdate("nb/client-ui"), + update: version.NewUpdateAndStart("nb/client-ui"), } s.eventHandler = newEventHandler(s) @@ -348,6 +385,10 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { s.showDebugUI() case args.showProfiles: s.showProfilesUI() + case args.showQuickActions: + s.showQuickActionsUI() + case args.showUpdate: + s.showUpdateProgress(ctx, args.showUpdateVersion) } return s @@ -424,18 +465,22 @@ func (s *serviceClient) showSettingsUI() { s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil) s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil) s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil) + s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil) + s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil) + s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil) + s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil) + s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil) + s.iSSHJWTCacheTTL = widget.NewEntry() s.wSettings.SetContent(s.getSettingsForm()) - s.wSettings.Resize(fyne.NewSize(600, 500)) + s.wSettings.Resize(fyne.NewSize(600, 400)) s.wSettings.SetFixedSize(true) s.getSrvConfig() s.wSettings.Show() } -// getSettingsForm to embed it into settings window. -func (s *serviceClient) getSettingsForm() *widget.Form { - +func (s *serviceClient) getConnectionForm() *widget.Form { var activeProfName string activeProf, err := s.profileManager.GetActiveProfile() if err != nil { @@ -446,164 +491,286 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return &widget.Form{ Items: []*widget.FormItem{ {Text: "Profile", Widget: widget.NewLabel(activeProfName)}, + {Text: "Management URL", Widget: s.iMngURL}, + {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, {Text: "MTU", Widget: s.iMTU}, - {Text: "Management URL", Widget: s.iMngURL}, - {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Log File", Widget: s.iLogFile}, + }, + } +} + +func (s *serviceClient) saveSettings() { + // Check if update settings are disabled by daemon + features, err := s.getFeatures() + if err != nil { + log.Errorf("failed to get features from daemon: %v", err) + // Continue with default behavior if features can't be retrieved + } else if features != nil && features.DisableUpdateSettings { + log.Warn("Configuration updates are disabled by daemon") + dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) + return + } + + if err := s.validateSettings(); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + port, mtu, err := s.parseNumericSettings() + if err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + iMngURL := strings.TrimSpace(s.iMngURL.Text) + + if s.hasSettingsChanged(iMngURL, port, mtu) { + if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + } + + s.wSettings.Close() +} + +func (s *serviceClient) validateSettings() error { + if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { + if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { + return fmt.Errorf("Invalid Pre-shared Key Value") + } + } + return nil +} + +func (s *serviceClient) parseNumericSettings() (int64, int64, error) { + port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) + if err != nil { + return 0, 0, errors.New("Invalid interface port") + } + if port < 1 || port > 65535 { + return 0, 0, errors.New("Invalid interface port: out of range 1-65535") + } + + var mtu int64 + mtuText := strings.TrimSpace(s.iMTU.Text) + if mtuText != "" { + mtu, err = strconv.ParseInt(mtuText, 10, 64) + if err != nil { + return 0, 0, errors.New("Invalid MTU value") + } + if mtu < iface.MinMTU || mtu > iface.MaxMTU { + return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU) + } + } + + return port, mtu, nil +} + +func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool { + return s.managementURL != iMngURL || + s.preSharedKey != s.iPreSharedKey.Text || + s.RosenpassPermissive != s.sRosenpassPermissive.Checked || + s.interfaceName != s.iInterfaceName.Text || + s.interfacePort != int(port) || + s.mtu != uint16(mtu) || + s.networkMonitor != s.sNetworkMonitor.Checked || + s.disableDNS != s.sDisableDNS.Checked || + s.disableClientRoutes != s.sDisableClientRoutes.Checked || + s.disableServerRoutes != s.sDisableServerRoutes.Checked || + s.blockLANAccess != s.sBlockLANAccess.Checked || + s.hasSSHChanges() +} + +func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error { + s.managementURL = iMngURL + s.preSharedKey = s.iPreSharedKey.Text + s.mtu = uint16(mtu) + + req, err := s.buildSetConfigRequest(iMngURL, port, mtu) + if err != nil { + return fmt.Errorf("build config request: %w", err) + } + + if err := s.sendConfigUpdate(req); err != nil { + return fmt.Errorf("set configuration: %w", err) + } + + return nil +} + +func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) { + currUser, err := user.Current() + if err != nil { + return nil, fmt.Errorf("get current user: %w", err) + } + + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + return nil, fmt.Errorf("get active profile: %w", err) + } + + req := &proto.SetConfigRequest{ + ProfileName: activeProf.Name, + Username: currUser.Username, + } + + if iMngURL != "" { + req.ManagementUrl = iMngURL + } + + req.RosenpassPermissive = &s.sRosenpassPermissive.Checked + req.InterfaceName = &s.iInterfaceName.Text + req.WireguardPort = &port + if mtu > 0 { + req.Mtu = &mtu + } + + req.NetworkMonitor = &s.sNetworkMonitor.Checked + req.DisableDns = &s.sDisableDNS.Checked + req.DisableClientRoutes = &s.sDisableClientRoutes.Checked + req.DisableServerRoutes = &s.sDisableServerRoutes.Checked + req.BlockLanAccess = &s.sBlockLANAccess.Checked + + req.EnableSSHRoot = &s.sEnableSSHRoot.Checked + req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked + req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked + req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked + req.DisableSSHAuth = &s.sDisableSSHAuth.Checked + + sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text) + if sshJWTCacheTTLText != "" { + sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32) + if err != nil { + return nil, errors.New("Invalid SSH JWT Cache TTL value") + } + if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL { + return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL) + } + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + req.SshJWTCacheTTL = &sshJWTCacheTTL32 + } + + if s.iPreSharedKey.Text != censoredPreSharedKey { + req.OptionalPreSharedKey = &s.iPreSharedKey.Text + } + + return req, nil +} + +func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error { + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return fmt.Errorf("get client: %w", err) + } + + _, err = conn.SetConfig(s.ctx, req) + if err != nil { + return fmt.Errorf("set config: %w", err) + } + + // Reconnect if connected to apply the new settings + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("get service status: %v", err) + return + } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + return + } + } + }() + + return nil +} + +func (s *serviceClient) getSettingsForm() fyne.CanvasObject { + connectionForm := s.getConnectionForm() + networkForm := s.getNetworkForm() + sshForm := s.getSSHForm() + tabs := container.NewAppTabs( + container.NewTabItem("Connection", connectionForm), + container.NewTabItem("Network", networkForm), + container.NewTabItem("SSH", sshForm), + ) + saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings) + saveButton.Importance = widget.HighImportance + cancelButton := widget.NewButtonWithIcon("Cancel", theme.CancelIcon(), func() { + s.wSettings.Close() + }) + buttonContainer := container.NewHBox( + layout.NewSpacer(), + cancelButton, + saveButton, + ) + return container.NewBorder(nil, buttonContainer, nil, nil, tabs) +} + +func (s *serviceClient) getNetworkForm() *widget.Form { + return &widget.Form{ + Items: []*widget.FormItem{ {Text: "Network Monitor", Widget: s.sNetworkMonitor}, {Text: "Disable DNS", Widget: s.sDisableDNS}, {Text: "Disable Client Routes", Widget: s.sDisableClientRoutes}, {Text: "Disable Server Routes", Widget: s.sDisableServerRoutes}, {Text: "Disable LAN Access", Widget: s.sBlockLANAccess}, }, - SubmitText: "Save", - OnSubmit: func() { - // Check if update settings are disabled by daemon - features, err := s.getFeatures() - if err != nil { - log.Errorf("failed to get features from daemon: %v", err) - // Continue with default behavior if features can't be retrieved - } else if features != nil && features.DisableUpdateSettings { - log.Warn("Configuration updates are disabled by daemon") - dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) - return - } + } +} - if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { - // validate preSharedKey if it added - if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { - dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings) - return - } - } - - port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid interface port"), s.wSettings) - return - } - - var mtu int64 - mtuText := strings.TrimSpace(s.iMTU.Text) - if mtuText != "" { - var err error - mtu, err = strconv.ParseInt(mtuText, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings) - return - } - if mtu < iface.MinMTU || mtu > iface.MaxMTU { - dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings) - return - } - } - - iMngURL := strings.TrimSpace(s.iMngURL.Text) - - defer s.wSettings.Close() - - // Check if any settings have changed - if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || - s.RosenpassPermissive != s.sRosenpassPermissive.Checked || - s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || - s.mtu != uint16(mtu) || - s.networkMonitor != s.sNetworkMonitor.Checked || - s.disableDNS != s.sDisableDNS.Checked || - s.disableClientRoutes != s.sDisableClientRoutes.Checked || - s.disableServerRoutes != s.sDisableServerRoutes.Checked || - s.blockLANAccess != s.sBlockLANAccess.Checked { - - s.managementURL = iMngURL - s.preSharedKey = s.iPreSharedKey.Text - s.mtu = uint16(mtu) - - currUser, err := user.Current() - if err != nil { - log.Errorf("get current user: %v", err) - return - } - - var req proto.SetConfigRequest - req.ProfileName = activeProf.Name - req.Username = currUser.Username - - if iMngURL != "" { - req.ManagementUrl = iMngURL - } - - req.RosenpassPermissive = &s.sRosenpassPermissive.Checked - req.InterfaceName = &s.iInterfaceName.Text - req.WireguardPort = &port - if mtu > 0 { - req.Mtu = &mtu - } - req.NetworkMonitor = &s.sNetworkMonitor.Checked - req.DisableDns = &s.sDisableDNS.Checked - req.DisableClientRoutes = &s.sDisableClientRoutes.Checked - req.DisableServerRoutes = &s.sDisableServerRoutes.Checked - req.BlockLanAccess = &s.sBlockLANAccess.Checked - - if s.iPreSharedKey.Text != censoredPreSharedKey { - req.OptionalPreSharedKey = &s.iPreSharedKey.Text - } - - conn, err := s.getSrvClient(failFastTimeout) - if err != nil { - log.Errorf("get client: %v", err) - dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings) - return - } - _, err = conn.SetConfig(s.ctx, &req) - if err != nil { - log.Errorf("set config: %v", err) - dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings) - return - } - - go func() { - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) - return - } - } - }() - } - }, - OnCancel: func() { - s.wSettings.Close() +func (s *serviceClient) getSSHForm() *widget.Form { + return &widget.Form{ + Items: []*widget.FormItem{ + {Text: "Enable SSH Root Login", Widget: s.sEnableSSHRoot}, + {Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP}, + {Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward}, + {Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward}, + {Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth}, + {Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL}, }, } } -func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { +func (s *serviceClient) hasSSHChanges() bool { + currentSSHJWTCacheTTL := s.sshJWTCacheTTL + if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" { + val, err := strconv.Atoi(text) + if err != nil { + return true + } + currentSSHJWTCacheTTL = val + } + + return s.enableSSHRoot != s.sEnableSSHRoot.Checked || + s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || + s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || + s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked || + s.disableSSHAuth != s.sDisableSSHAuth.Checked || + s.sshJWTCacheTTL != currentSSHJWTCacheTTL +} + +func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return nil, err + return nil, fmt.Errorf("get daemon client: %w", err) } activeProf, err := s.profileManager.GetActiveProfile() if err != nil { - log.Errorf("get active profile: %v", err) - return nil, err + return nil, fmt.Errorf("get active profile: %w", err) } currUser, err := user.Current() @@ -611,84 +778,82 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + loginReq := &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, - }) + } + + profileState, err := s.profileManager.GetProfileState(activeProf.Name) if err != nil { - log.Errorf("login to management URL with: %v", err) - return nil, err + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginReq.Hint = &profileState.Email + } + + loginResp, err := conn.Login(ctx, loginReq) + if err != nil { + return nil, fmt.Errorf("login to management: %w", err) } if loginResp.NeedsSSOLogin && openURL { - err = s.handleSSOLogin(loginResp, conn) - if err != nil { - log.Errorf("handle SSO login failed: %v", err) - return nil, err + if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil { + return nil, fmt.Errorf("SSO login: %w", err) } } return loginResp, nil } -func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) - if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return err +func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + if err := openURL(loginResp.VerificationURIComplete); err != nil { + return fmt.Errorf("open browser: %w", err) } - resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { - log.Errorf("waiting sso login failed with: %v", err) - return err + return fmt.Errorf("wait for SSO login: %w", err) } if resp.Email != "" { - err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ Email: resp.Email, - }) - if err != nil { - log.Warnf("failed to set profile state: %v", err) + }); err != nil { + log.Debugf("failed to set profile state: %v", err) } else { s.mProfile.refresh() } - } return nil } -func (s *serviceClient) menuUpClick() error { +func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { systray.SetTemplateIcon(iconErrorMacOS, s.icError) - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } - _, err = s.login(true) + _, err = s.login(ctx, true) if err != nil { - log.Errorf("login failed with: %v", err) - return err + return fmt.Errorf("login: %w", err) } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status == string(internal.StatusConnected) { - log.Warnf("already connected") return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - log.Errorf("up service: %v", err) - return err + if _, err := s.conn.Up(s.ctx, &proto.UpRequest{ + AutoUpdate: protobuf.Bool(wannaAutoUpdate), + }); err != nil { + return fmt.Errorf("start connection: %w", err) } return nil @@ -698,24 +863,20 @@ func (s *serviceClient) menuDownClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { - log.Warnf("already down") return nil } - if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("down service: %v", err) - return err + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("stop connection: %w", err) } return nil @@ -748,7 +909,7 @@ func (s *serviceClient) updateStatus() error { var systrayIconState bool switch { - case status.Status == string(internal.StatusConnected): + case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled(): s.connected = true s.sendNotification = true if s.isUpdateIconActive { @@ -762,6 +923,7 @@ func (s *serviceClient) updateStatus() error { s.mUp.Disable() s.mDown.Enable() s.mNetworks.Enable() + s.mExitNode.Enable() go s.updateExitNodes() systrayIconState = true case status.Status == string(internal.StatusConnecting): @@ -851,6 +1013,7 @@ func (s *serviceClient) onTrayReady() { newProfileMenuArgs := &newProfileMenuArgs{ ctx: s.ctx, + serviceClient: s, profileManager: s.profileManager, eventHandler: s.eventHandler, profileMenuItem: profileMenuItem, @@ -951,9 +1114,32 @@ func (s *serviceClient) onTrayReady() { s.updateExitNodes() } }) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + // todo use new Category + if windowAction, ok := event.Metadata["progress_window"]; ok { + targetVersion, ok := event.Metadata["version"] + if !ok { + targetVersion = "unknown" + } + log.Debugf("window action: %v", windowAction) + if windowAction == "show" { + if s.updateContextCancel != nil { + s.updateContextCancel() + s.updateContextCancel = nil + } + + subCtx, cancel := context.WithCancel(s.ctx) + go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion) + s.updateContextCancel = cancel + } + } + }) go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) + + // Start sleep detection listener + go s.startSleepListener() } func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { @@ -990,6 +1176,8 @@ func (s *serviceClient) onTrayExit() { // getSrvClient connection to the service. func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) { + s.connLock.Lock() + defer s.connLock.Unlock() if s.conn != nil { return s.conn, nil } @@ -1012,6 +1200,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } +// startSleepListener initializes the sleep detection service and listens for sleep events +func (s *serviceClient) startSleepListener() { + sleepService, err := sleep.New() + if err != nil { + log.Warnf("%v", err) + return + } + + if err := sleepService.Register(s.handleSleepEvents); err != nil { + log.Errorf("failed to start sleep detection: %v", err) + return + } + + log.Info("sleep detection service initialized") + + // Cleanup on context cancellation + go func() { + <-s.ctx.Done() + log.Info("stopping sleep event listener") + if err := sleepService.Deregister(); err != nil { + log.Errorf("failed to deregister sleep detection: %v", err) + } + }() +} + +// handleSleepEvents sends a sleep notification to the daemon via gRPC +func (s *serviceClient) handleSleepEvents(event sleep.EventType) { + conn, err := s.getSrvClient(0) + if err != nil { + log.Errorf("failed to get daemon client for sleep notification: %v", err) + return + } + + req := &proto.OSLifecycleRequest{} + + switch event { + case sleep.EventTypeWakeUp: + log.Infof("handle wakeup event: %v", event) + req.Type = proto.OSLifecycleRequest_WAKEUP + case sleep.EventTypeSleep: + log.Infof("handle sleep event: %v", event) + req.Type = proto.OSLifecycleRequest_SLEEP + default: + log.Infof("unknown event: %v", event) + return + } + + _, err = conn.NotifyOSLifecycle(s.ctx, req) + if err != nil { + log.Errorf("failed to notify daemon about os lifecycle notification: %v", err) + return + } + + log.Info("successfully notified daemon about os lifecycle") +} + // setSettingsEnabled enables or disables the settings menu based on the provided state func (s *serviceClient) setSettingsEnabled(enabled bool) { if s.mSettings != nil { @@ -1033,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() { return } + s.updateIndicationLock.Lock() + defer s.updateIndicationLock.Unlock() + // Update settings menu based on current features - if features != nil && features.DisableUpdateSettings { - s.setSettingsEnabled(false) - } else { - s.setSettingsEnabled(true) + settingsEnabled := features == nil || !features.DisableUpdateSettings + if s.settingsEnabled != settingsEnabled { + s.settingsEnabled = settingsEnabled + s.setSettingsEnabled(settingsEnabled) } // Update profile menu based on current features if s.mProfile != nil { - if features != nil && features.DisableProfiles { - s.mProfile.setEnabled(false) - } else { - s.mProfile.setEnabled(true) + profilesEnabled := features == nil || !features.DisableProfiles + if s.profilesEnabled != profilesEnabled { + s.profilesEnabled = profilesEnabled + s.mProfile.setEnabled(profilesEnabled) } } } @@ -1121,6 +1368,25 @@ func (s *serviceClient) getSrvConfig() { s.disableServerRoutes = cfg.DisableServerRoutes s.blockLANAccess = cfg.BlockLANAccess + if cfg.EnableSSHRoot != nil { + s.enableSSHRoot = *cfg.EnableSSHRoot + } + if cfg.EnableSSHSFTP != nil { + s.enableSSHSFTP = *cfg.EnableSSHSFTP + } + if cfg.EnableSSHLocalPortForwarding != nil { + s.enableSSHLocalPortForward = *cfg.EnableSSHLocalPortForwarding + } + if cfg.EnableSSHRemotePortForwarding != nil { + s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding + } + if cfg.DisableSSHAuth != nil { + s.disableSSHAuth = *cfg.DisableSSHAuth + } + if cfg.SSHJWTCacheTTL != nil { + s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL + } + if s.showAdvancedSettings { s.iMngURL.SetText(s.managementURL) s.iPreSharedKey.SetText(cfg.PreSharedKey) @@ -1141,6 +1407,24 @@ func (s *serviceClient) getSrvConfig() { s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes) s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes) s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess) + if cfg.EnableSSHRoot != nil { + s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot) + } + if cfg.EnableSSHSFTP != nil { + s.sEnableSSHSFTP.SetChecked(*cfg.EnableSSHSFTP) + } + if cfg.EnableSSHLocalPortForwarding != nil { + s.sEnableSSHLocalPortForward.SetChecked(*cfg.EnableSSHLocalPortForwarding) + } + if cfg.EnableSSHRemotePortForwarding != nil { + s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding) + } + if cfg.DisableSSHAuth != nil { + s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth) + } + if cfg.SSHJWTCacheTTL != nil { + s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL)) + } } if s.mNotifications == nil { @@ -1211,6 +1495,15 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.DisableServerRoutes = cfg.DisableServerRoutes config.BlockLANAccess = cfg.BlockLanAccess + config.EnableSSHRoot = &cfg.EnableSSHRoot + config.EnableSSHSFTP = &cfg.EnableSSHSFTP + config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding + config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding + config.DisableSSHAuth = &cfg.DisableSSHAuth + + ttl := int(cfg.SshJWTCacheTTL) + config.SSHJWTCacheTTL = &ttl + return &config } @@ -1382,7 +1675,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - resp, err := s.login(false) + resp, err := s.login(ctx, false) if err != nil { log.Errorf("failed to fetch login URL: %v", err) return @@ -1402,7 +1695,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) if err != nil { log.Errorf("Waiting sso login failed with: %v", err) label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") @@ -1410,7 +1703,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } label.SetText("Re-authentication successful.\nReconnecting") - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) return @@ -1423,7 +1716,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) + _, err = conn.Up(ctx, &proto.UpRequest{}) if err != nil { label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") log.Errorf("Reconnecting failed with: %v", err) @@ -1487,6 +1780,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..51fa28575 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -493,7 +500,7 @@ func (s *serviceClient) createDebugBundleFromCollection( if uploadFailureReason != "" { showUploadFailedDialog(progress.window, localPath, uploadFailureReason) } else { - showUploadSuccessDialog(progress.window, localPath, uploadedKey) + showUploadSuccessDialog(s.app, progress.window, localPath, uploadedKey) } } else { showBundleCreatedDialog(progress.window, localPath) @@ -558,7 +565,7 @@ func (s *serviceClient) handleDebugCreation( if uploadFailureReason != "" { showUploadFailedDialog(w, localPath, uploadFailureReason) } else { - showUploadSuccessDialog(w, localPath, uploadedKey) + showUploadSuccessDialog(s.app, w, localPath, uploadedKey) } } else { showBundleCreatedDialog(w, localPath) @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } @@ -652,7 +665,7 @@ func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) { } // showUploadSuccessDialog displays a dialog when upload succeeds -func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) { +func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey string) { log.Infof("Upload key: %s", uploadedKey) keyEntry := widget.NewEntry() keyEntry.SetText(uploadedKey) @@ -670,7 +683,7 @@ func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) { customDialog := dialog.NewCustom("Upload Successful", "OK", content, w) copyBtn := createButtonWithAction("Copy key", func() { - w.Clipboard().SetContent(uploadedKey) + a.Clipboard().SetContent(uploadedKey) log.Info("Upload key copied to clipboard") }) diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index e9b7f4f30..9ffacd926 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" @@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) { func (h *eventHandler) handleConnectClick() { h.client.mUp.Disable() + + if h.client.connectCancel != nil { + h.client.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(h.client.ctx) + h.client.connectCancel = connectCancel + go func() { - defer h.client.mUp.Enable() - if err := h.client.menuUpClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) + defer connectCancel() + + if err := h.client.menuUpClick(connectCtx, true); err != nil { + st, ok := status.FromError(err) + if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { + log.Debugf("connect operation cancelled by user") + } else { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + log.Errorf("connect failed: %v", err) + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after connect: %v", err) } }() } func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + + if h.client.connectCancel != nil { + log.Debugf("cancelling ongoing connect operation") + h.client.connectCancel() + h.client.connectCancel = nil + } + go func() { - defer h.client.mDown.Enable() if err := h.client.menuDownClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) + st, ok := status.FromError(err) + if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + log.Errorf("disconnect failed: %v", err) + } else { + log.Debugf("disconnect cancelled or already disconnecting") + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after disconnect: %v", err) } }() } @@ -148,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() { go func() { defer h.client.mAdvancedSettings.Enable() defer h.client.getSrvConfig() - h.runSelfCommand(h.client.ctx, "settings", "true") + h.runSelfCommand(h.client.ctx, "settings") }() } @@ -156,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() { h.client.mCreateDebugBundle.Disable() go func() { defer h.client.mCreateDebugBundle.Enable() - h.runSelfCommand(h.client.ctx, "debug", "true") + h.runSelfCommand(h.client.ctx, "debug") }() } @@ -180,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() { h.client.mNetworks.Disable() go func() { defer h.client.mNetworks.Enable() - h.runSelfCommand(h.client.ctx, "networks", "true") + h.runSelfCommand(h.client.ctx, "networks") }() } @@ -200,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error { return nil } -func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) { +func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) { proc, err := os.Executable() if err != nil { log.Errorf("error getting executable path: %v", err) return } - cmd := exec.CommandContext(ctx, proc, - fmt.Sprintf("--%s=%s", command, arg), + // Build the full command arguments + cmdArgs := []string{ + fmt.Sprintf("--%s=true", command), fmt.Sprintf("--daemon-addr=%s", h.client.addr), - ) + } + cmdArgs = append(cmdArgs, args...) + + cmd := exec.CommandContext(ctx, proc, cmdArgs...) if out := h.client.attachOutput(cmd); out != nil { defer func() { @@ -220,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) }() } - log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr) + log.Printf("running command: %s", cmd.String()) if err := cmd.Run(); err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode()) + log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode()) } return } - log.Printf("command '%s %s' completed successfully", command, arg) + log.Printf("command '%s' completed successfully", cmd.String()) } func (h *eventHandler) logout(ctx context.Context) error { @@ -245,6 +286,6 @@ func (h *eventHandler) logout(ctx context.Context) error { } h.client.getSrvConfig() - + return nil } diff --git a/client/ui/icons.go b/client/ui/icons.go index e88fb9378..874f24fdd 100644 --- a/client/ui/icons.go +++ b/client/ui/icons.go @@ -9,6 +9,9 @@ import ( //go:embed assets/netbird.png var iconAbout []byte +//go:embed assets/netbird-disconnected.png +var iconAboutDisconnected []byte + //go:embed assets/netbird-systemtray-connected.png var iconConnected []byte diff --git a/client/ui/icons_windows.go b/client/ui/icons_windows.go index 2107d3852..bd57b2690 100644 --- a/client/ui/icons_windows.go +++ b/client/ui/icons_windows.go @@ -7,6 +7,9 @@ import ( //go:embed assets/netbird.ico var iconAbout []byte +//go:embed assets/netbird-disconnected.ico +var iconAboutDisconnected []byte + //go:embed assets/netbird-systemtray-connected.ico var iconConnected []byte diff --git a/client/ui/process/process.go b/client/ui/process/process.go index d0ef54896..28276f416 100644 --- a/client/ui/process/process.go +++ b/client/ui/process/process.go @@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) { continue } - if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) { + runningProcessName := strings.ToLower(filepath.Base(runningProcessPath)) + if runningProcessName == processName && isProcessOwnedByCurrentUser(p) { return p.Pid, true, nil } } diff --git a/client/ui/profile.go b/client/ui/profile.go index 075223795..a38d8918a 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -387,6 +387,7 @@ type subItem struct { type profileMenu struct { mu sync.Mutex ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem @@ -396,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context, bool) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -404,12 +405,13 @@ type profileMenu struct { type newProfileMenuArgs struct { ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context, bool) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -418,6 +420,7 @@ type newProfileMenuArgs struct { func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ ctx: args.ctx, + serviceClient: args.serviceClient, profileManager: args.profileManager, eventHandler: args.eventHandler, profileMenuItem: args.profileMenuItem, @@ -569,10 +572,19 @@ func (p *profileMenu) refresh() { } } - if err := p.upClickCallback(); err != nil { + if p.serviceClient.connectCancel != nil { + p.serviceClient.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(p.ctx) + p.serviceClient.connectCancel = connectCancel + + if err := p.upClickCallback(connectCtx, false); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } + connectCancel() + p.refresh() p.loadSettingsCallback() } diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go new file mode 100644 index 000000000..76440d684 --- /dev/null +++ b/client/ui/quickactions.go @@ -0,0 +1,349 @@ +//go:build !(linux && 386) + +//go:generate fyne bundle -o quickactions_assets.go assets/connected.png +//go:generate fyne bundle -o quickactions_assets.go -append assets/disconnected.png +package main + +import ( + "context" + _ "embed" + "fmt" + "runtime" + "sync/atomic" + "time" + + "fyne.io/fyne/v2" + "fyne.io/fyne/v2/canvas" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/layout" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" +) + +type quickActionsUiState struct { + connectionStatus string + isToggleButtonEnabled bool + isConnectionChanged bool + toggleAction func() +} + +func newQuickActionsUiState() quickActionsUiState { + return quickActionsUiState{ + connectionStatus: string(internal.StatusIdle), + isToggleButtonEnabled: false, + isConnectionChanged: false, + } +} + +type clientConnectionStatusProvider interface { + connectionStatus(ctx context.Context) (string, error) +} + +type daemonClientConnectionStatusProvider struct { + client proto.DaemonServiceClient +} + +func (d daemonClientConnectionStatusProvider) connectionStatus(ctx context.Context) (string, error) { + childCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond) + defer cancel() + status, err := d.client.Status(childCtx, &proto.StatusRequest{}) + if err != nil { + return "", err + } + + return status.Status, nil +} + +type clientCommand interface { + execute() error +} + +type connectCommand struct { + connectClient func() error +} + +func (c connectCommand) execute() error { + return c.connectClient() +} + +type disconnectCommand struct { + disconnectClient func() error +} + +func (c disconnectCommand) execute() error { + return c.disconnectClient() +} + +type quickActionsViewModel struct { + provider clientConnectionStatusProvider + connect clientCommand + disconnect clientCommand + uiChan chan quickActionsUiState + isWatchingConnectionStatus atomic.Bool +} + +func newQuickActionsViewModel(ctx context.Context, provider clientConnectionStatusProvider, connect, disconnect clientCommand, uiChan chan quickActionsUiState) { + viewModel := quickActionsViewModel{ + provider: provider, + connect: connect, + disconnect: disconnect, + uiChan: uiChan, + } + + viewModel.isWatchingConnectionStatus.Store(true) + + // base UI status + uiChan <- newQuickActionsUiState() + + // this retrieves the current connection status + // and pushes the UI state that reflects it via uiChan + go viewModel.watchConnectionStatus(ctx) +} + +func (q *quickActionsViewModel) updateUiState(ctx context.Context) { + uiState := newQuickActionsUiState() + connectionStatus, err := q.provider.connectionStatus(ctx) + + if err != nil { + log.Errorf("Status: Error - %v", err) + q.uiChan <- uiState + return + } + + if connectionStatus == string(internal.StatusConnected) { + uiState.toggleAction = func() { + q.executeCommand(q.disconnect) + } + } else { + uiState.toggleAction = func() { + q.executeCommand(q.connect) + } + } + + uiState.isToggleButtonEnabled = true + uiState.connectionStatus = connectionStatus + q.uiChan <- uiState +} + +func (q *quickActionsViewModel) watchConnectionStatus(ctx context.Context) { + ticker := time.NewTicker(1000 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if q.isWatchingConnectionStatus.Load() { + q.updateUiState(ctx) + } + } + } +} + +func (q *quickActionsViewModel) executeCommand(command clientCommand) { + uiState := newQuickActionsUiState() + // newQuickActionsUiState starts with Idle connection status, + // and all that's necessary here is to just disable the toggle button. + uiState.connectionStatus = "" + + q.uiChan <- uiState + + q.isWatchingConnectionStatus.Store(false) + + err := command.execute() + + if err != nil { + log.Errorf("Status: Error - %v", err) + q.isWatchingConnectionStatus.Store(true) + } else { + uiState = newQuickActionsUiState() + uiState.isConnectionChanged = true + q.uiChan <- uiState + } +} + +func getSystemTrayName() string { + os := runtime.GOOS + switch os { + case "darwin": + return "menu bar" + default: + return "system tray" + } +} + +func (s *serviceClient) getNetBirdImage(name string, content []byte) *canvas.Image { + imageSize := fyne.NewSize(64, 64) + + resource := fyne.NewStaticResource(name, content) + image := canvas.NewImageFromResource(resource) + image.FillMode = canvas.ImageFillContain + image.SetMinSize(imageSize) + image.Resize(imageSize) + + return image +} + +type quickActionsUiComponents struct { + content *fyne.Container + toggleConnectionButton *widget.Button + connectedLabelText, disconnectedLabelText string + connectedImage, disconnectedImage *canvas.Image + connectedCircleRes, disconnectedCircleRes fyne.Resource +} + +// applyQuickActionsUiState applies a single UI state to the quick actions window. +// It closes the window and returns true if the connection status has changed, +// in which case the caller should stop processing further states. +func (s *serviceClient) applyQuickActionsUiState( + uiState quickActionsUiState, + components quickActionsUiComponents, +) bool { + if uiState.isConnectionChanged { + fyne.DoAndWait(func() { + s.wQuickActions.Close() + }) + return true + } + + var logo *canvas.Image + var buttonText string + var buttonIcon fyne.Resource + + if uiState.connectionStatus == string(internal.StatusConnected) { + buttonText = components.connectedLabelText + buttonIcon = components.connectedCircleRes + logo = components.connectedImage + } else if uiState.connectionStatus == string(internal.StatusIdle) { + buttonText = components.disconnectedLabelText + buttonIcon = components.disconnectedCircleRes + logo = components.disconnectedImage + } + + fyne.DoAndWait(func() { + if buttonText != "" { + components.toggleConnectionButton.SetText(buttonText) + } + + if buttonIcon != nil { + components.toggleConnectionButton.SetIcon(buttonIcon) + } + + if uiState.isToggleButtonEnabled { + components.toggleConnectionButton.Enable() + } else { + components.toggleConnectionButton.Disable() + } + + components.toggleConnectionButton.OnTapped = func() { + if uiState.toggleAction != nil { + go uiState.toggleAction() + } + } + + components.toggleConnectionButton.Refresh() + + // the second position in the content's object array is the NetBird logo. + if logo != nil { + components.content.Objects[1] = logo + components.content.Refresh() + } + }) + + return false +} + +// showQuickActionsUI displays a simple window with the NetBird logo and a connection toggle button. +func (s *serviceClient) showQuickActionsUI() { + s.wQuickActions = s.app.NewWindow("NetBird") + vmCtx, vmCancel := context.WithCancel(s.ctx) + s.wQuickActions.SetOnClosed(vmCancel) + + client, err := s.getSrvClient(defaultFailTimeout) + + connCmd := connectCommand{ + connectClient: func() error { + return s.menuUpClick(s.ctx, false) + }, + } + + disConnCmd := disconnectCommand{ + disconnectClient: func() error { + return s.menuDownClick() + }, + } + + if err != nil { + log.Errorf("get service client: %v", err) + return + } + + uiChan := make(chan quickActionsUiState, 1) + newQuickActionsViewModel(vmCtx, daemonClientConnectionStatusProvider{client: client}, connCmd, disConnCmd, uiChan) + + connectedImage := s.getNetBirdImage("netbird.png", iconAbout) + disconnectedImage := s.getNetBirdImage("netbird-disconnected.png", iconAboutDisconnected) + + connectedCircle := canvas.NewImageFromResource(resourceConnectedPng) + disconnectedCircle := canvas.NewImageFromResource(resourceDisconnectedPng) + + connectedLabelText := "Disconnect" + disconnectedLabelText := "Connect" + + toggleConnectionButton := widget.NewButtonWithIcon(disconnectedLabelText, disconnectedCircle.Resource, func() { + // This button's tap function will be set when an ui state arrives via the uiChan channel. + }) + + // Button starts disabled until the first ui state arrives. + toggleConnectionButton.Disable() + + hintLabelText := fmt.Sprintf("You can always access NetBird from your %s.", getSystemTrayName()) + hintLabel := widget.NewLabel(hintLabelText) + + content := container.NewVBox( + layout.NewSpacer(), + disconnectedImage, + layout.NewSpacer(), + container.NewCenter(toggleConnectionButton), + layout.NewSpacer(), + container.NewCenter(hintLabel), + ) + + // this watches for ui state updates. + go func() { + + for { + select { + case <-vmCtx.Done(): + return + case uiState, ok := <-uiChan: + if !ok { + return + } + + closed := s.applyQuickActionsUiState( + uiState, + quickActionsUiComponents{ + content, + toggleConnectionButton, + connectedLabelText, disconnectedLabelText, + connectedImage, disconnectedImage, + connectedCircle.Resource, disconnectedCircle.Resource, + }, + ) + if closed { + return + } + } + } + }() + + s.wQuickActions.SetContent(content) + s.wQuickActions.Resize(fyne.NewSize(400, 200)) + s.wQuickActions.SetFixedSize(true) + s.wQuickActions.Show() +} diff --git a/client/ui/quickactions_assets.go b/client/ui/quickactions_assets.go new file mode 100644 index 000000000..9ff5e85a2 --- /dev/null +++ b/client/ui/quickactions_assets.go @@ -0,0 +1,23 @@ +// auto-generated +// Code generated by '$ fyne bundle'. DO NOT EDIT. + +package main + +import ( + _ "embed" + "fyne.io/fyne/v2" +) + +//go:embed assets/connected.png +var resourceConnectedPngData []byte +var resourceConnectedPng = &fyne.StaticResource{ + StaticName: "assets/connected.png", + StaticContent: resourceConnectedPngData, +} + +//go:embed assets/disconnected.png +var resourceDisconnectedPngData []byte +var resourceDisconnectedPng = &fyne.StaticResource{ + StaticName: "assets/disconnected.png", + StaticContent: resourceDisconnectedPngData, +} diff --git a/client/ui/signal_unix.go b/client/ui/signal_unix.go new file mode 100644 index 000000000..99de99f0f --- /dev/null +++ b/client/ui/signal_unix.go @@ -0,0 +1,76 @@ +//go:build !windows && !(linux && 386) + +package main + +import ( + "context" + "os" + "os/exec" + "os/signal" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// setupSignalHandler sets up a signal handler to listen for SIGUSR1. +// When received, it opens the quick actions window. +func (s *serviceClient) setupSignalHandler(ctx context.Context) { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGUSR1) + + go func() { + for { + select { + case <-ctx.Done(): + return + case <-sigChan: + log.Info("received SIGUSR1 signal, opening quick actions window") + s.openQuickActions() + } + } + }() +} + +// openQuickActions opens the quick actions window by spawning a new process. +func (s *serviceClient) openQuickActions() { + proc, err := os.Executable() + if err != nil { + log.Errorf("get executable path: %v", err) + return + } + + cmd := exec.CommandContext(s.ctx, proc, + "--quick-actions=true", + "--daemon-addr="+s.addr, + ) + + if out := s.attachOutput(cmd); out != nil { + defer func() { + if err := out.Close(); err != nil { + log.Errorf("close log file %s: %v", s.logFile, err) + } + }() + } + + log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr) + + if err := cmd.Start(); err != nil { + log.Errorf("start quick actions window: %v", err) + return + } + + go func() { + if err := cmd.Wait(); err != nil { + log.Debugf("quick actions window exited: %v", err) + } + }() +} + +// sendShowWindowSignal sends SIGUSR1 to the specified PID. +func sendShowWindowSignal(pid int32) error { + process, err := os.FindProcess(int(pid)) + if err != nil { + return err + } + return process.Signal(syscall.SIGUSR1) +} diff --git a/client/ui/signal_windows.go b/client/ui/signal_windows.go new file mode 100644 index 000000000..ca98be526 --- /dev/null +++ b/client/ui/signal_windows.go @@ -0,0 +1,171 @@ +//go:build windows + +package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + quickActionsTriggerEventName = `Global\NetBirdQuickActionsTriggerEvent` + waitTimeout = 5 * time.Second + // SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent. + desiredAccesses = windows.SYNCHRONIZE | windows.EVENT_MODIFY_STATE +) + +func getEventNameUint16Pointer() (*uint16, error) { + eventNamePtr, err := windows.UTF16PtrFromString(quickActionsTriggerEventName) + if err != nil { + log.Errorf("Failed to convert event name '%s' to UTF16: %v", quickActionsTriggerEventName, err) + return nil, err + } + + return eventNamePtr, nil +} + +// setupSignalHandler sets up signal handling for Windows. +// Windows doesn't support SIGUSR1, so this uses a similar approach using windows.Events. +func (s *serviceClient) setupSignalHandler(ctx context.Context) { + eventNamePtr, err := getEventNameUint16Pointer() + if err != nil { + return + } + + eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr) + + if err != nil { + if errors.Is(err, windows.ERROR_ALREADY_EXISTS) { + log.Warnf("Quick actions trigger event '%s' already exists. Attempting to open.", quickActionsTriggerEventName) + eventHandle, err = windows.OpenEvent(desiredAccesses, false, eventNamePtr) + if err != nil { + log.Errorf("Failed to open existing quick actions trigger event '%s': %v", quickActionsTriggerEventName, err) + return + } + log.Infof("Successfully opened existing quick actions trigger event '%s'.", quickActionsTriggerEventName) + } else { + log.Errorf("Failed to create quick actions trigger event '%s': %v", quickActionsTriggerEventName, err) + return + } + } + + if eventHandle == windows.InvalidHandle { + log.Errorf("Obtained an invalid handle for quick actions trigger event '%s'", quickActionsTriggerEventName) + return + } + + log.Infof("Quick actions handler waiting for signal on event: %s", quickActionsTriggerEventName) + + go s.waitForEvent(ctx, eventHandle) +} + +func (s *serviceClient) waitForEvent(ctx context.Context, eventHandle windows.Handle) { + defer func() { + if err := windows.CloseHandle(eventHandle); err != nil { + log.Errorf("Failed to close quick actions event handle '%s': %v", quickActionsTriggerEventName, err) + } + }() + + for { + if ctx.Err() != nil { + return + } + + status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds())) + + switch status { + case windows.WAIT_OBJECT_0: + log.Info("Received signal on quick actions event. Opening quick actions window.") + + // reset the event so it can be triggered again later (manual reset == 1) + if err := windows.ResetEvent(eventHandle); err != nil { + log.Errorf("Failed to reset quick actions event '%s': %v", quickActionsTriggerEventName, err) + } + + s.openQuickActions() + case uint32(windows.WAIT_TIMEOUT): + + default: + if isDone := logUnexpectedStatus(ctx, status, err); isDone { + return + } + } + } +} + +func logUnexpectedStatus(ctx context.Context, status uint32, err error) bool { + log.Errorf("Unexpected status %d from WaitForSingleObject for quick actions event '%s': %v", + status, quickActionsTriggerEventName, err) + select { + case <-time.After(5 * time.Second): + return false + case <-ctx.Done(): + return true + } +} + +// openQuickActions opens the quick actions window by spawning a new process. +func (s *serviceClient) openQuickActions() { + proc, err := os.Executable() + if err != nil { + log.Errorf("get executable path: %v", err) + return + } + + cmd := exec.CommandContext(s.ctx, proc, + "--quick-actions=true", + "--daemon-addr="+s.addr, + ) + + if out := s.attachOutput(cmd); out != nil { + defer func() { + if err := out.Close(); err != nil { + log.Errorf("close log file %s: %v", s.logFile, err) + } + }() + } + + log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr) + + if err := cmd.Start(); err != nil { + log.Errorf("error starting quick actions window: %v", err) + return + } + + go func() { + if err := cmd.Wait(); err != nil { + log.Debugf("quick actions window exited: %v", err) + } + }() +} + +func sendShowWindowSignal(pid int32) error { + _, err := os.FindProcess(int(pid)) + if err != nil { + return err + } + + eventNamePtr, err := getEventNameUint16Pointer() + if err != nil { + return err + } + + eventHandle, err := windows.OpenEvent(desiredAccesses, false, eventNamePtr) + if err != nil { + return err + } + + err = windows.SetEvent(eventHandle) + if err != nil { + return fmt.Errorf("Error setting event: %w", err) + } + + return nil +} diff --git a/client/ui/update.go b/client/ui/update.go new file mode 100644 index 000000000..25c317bdf --- /dev/null +++ b/client/ui/update.go @@ -0,0 +1,140 @@ +//go:build !(linux && 386) + +package main + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) { + log.Infof("show installer progress window: %s", version) + s.wUpdateProgress = s.app.NewWindow("Automatically updating client") + + statusLabel := widget.NewLabel("Updating...") + infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version)) + content := container.NewVBox(infoLabel, statusLabel) + s.wUpdateProgress.SetContent(content) + s.wUpdateProgress.CenterOnScreen() + s.wUpdateProgress.SetFixedSize(true) + s.wUpdateProgress.SetCloseIntercept(func() { + // this is empty to lock window until result known + }) + s.wUpdateProgress.RequestFocus() + s.wUpdateProgress.Show() + + updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) + + // Initialize dot updater + updateText := dotUpdater() + + // Channel to receive the result from RPC call + resultErrCh := make(chan error, 1) + resultOkCh := make(chan struct{}, 1) + + // Start RPC call in background + go func() { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Infof("backend not reachable, upgrade in progress: %v", err) + close(resultOkCh) + return + } + + resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{}) + if err != nil { + log.Infof("backend stopped responding, upgrade in progress: %v", err) + close(resultOkCh) + return + } + + if !resp.Success { + resultErrCh <- mapInstallError(resp.ErrorMsg) + return + } + + // Success + close(resultOkCh) + }() + + // Update UI with dots and wait for result + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + defer cancel() + + // allow closing update window after 10 sec + timerResetCloseInterceptor := time.NewTimer(10 * time.Second) + defer timerResetCloseInterceptor.Stop() + + for { + select { + case <-updateWindowCtx.Done(): + s.showInstallerResult(statusLabel, updateWindowCtx.Err()) + return + case err := <-resultErrCh: + s.showInstallerResult(statusLabel, err) + return + case <-resultOkCh: + log.Info("backend exited, upgrade in progress, closing all UI") + killParentUIProcess() + s.app.Quit() + return + case <-ticker.C: + statusLabel.SetText(updateText()) + case <-timerResetCloseInterceptor.C: + s.wUpdateProgress.SetCloseIntercept(nil) + } + } + }() +} + +func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) { + s.wUpdateProgress.SetCloseIntercept(nil) + switch { + case errors.Is(err, context.DeadlineExceeded): + log.Warn("update watcher timed out") + statusLabel.SetText("Update timed out. Please try again.") + case errors.Is(err, context.Canceled): + log.Info("update watcher canceled") + statusLabel.SetText("Update canceled.") + case err != nil: + log.Errorf("update failed: %v", err) + statusLabel.SetText("Update failed: " + err.Error()) + default: + s.wUpdateProgress.Close() + } +} + +// dotUpdater returns a closure that cycles through dots for a loading animation. +func dotUpdater() func() string { + dotCount := 0 + return func() string { + dotCount = (dotCount + 1) % 4 + return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount)) + } +} + +func mapInstallError(msg string) error { + msg = strings.ToLower(strings.TrimSpace(msg)) + + switch { + case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"): + return context.DeadlineExceeded + case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"): + return context.Canceled + case msg == "": + return errors.New("unknown update error") + default: + return errors.New(msg) + } +} diff --git a/client/ui/update_notwindows.go b/client/ui/update_notwindows.go new file mode 100644 index 000000000..5766f18f7 --- /dev/null +++ b/client/ui/update_notwindows.go @@ -0,0 +1,7 @@ +//go:build !windows && !(linux && 386) + +package main + +func killParentUIProcess() { + // No-op on non-Windows platforms +} diff --git a/client/ui/update_windows.go b/client/ui/update_windows.go new file mode 100644 index 000000000..1b03936f9 --- /dev/null +++ b/client/ui/update_windows.go @@ -0,0 +1,44 @@ +//go:build windows + +package main + +import ( + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + + nbprocess "github.com/netbirdio/netbird/client/ui/process" +) + +// killParentUIProcess finds and kills the parent systray UI process on Windows. +// This is a workaround in case the MSI installer fails to properly terminate the UI process. +// The installer should handle this via util:CloseApplication with TerminateProcess, but this +// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds. +func killParentUIProcess() { + pid, running, err := nbprocess.IsAnotherProcessRunning() + if err != nil { + log.Warnf("failed to check for parent UI process: %v", err) + return + } + + if !running { + log.Debug("no parent UI process found to kill") + return + } + + log.Infof("killing parent UI process (PID: %d)", pid) + + // Open the process with terminate rights + handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid)) + if err != nil { + log.Warnf("failed to open parent process %d: %v", pid, err) + return + } + defer func() { + _ = windows.CloseHandle(handle) + }() + + // Terminate the process with exit code 0 + if err := windows.TerminateProcess(handle, 0); err != nil { + log.Warnf("failed to terminate parent process %d: %v", pid, err) + } +} diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d542e2739..238e272fa 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" netbird "github.com/netbirdio/netbird/client/embed" + sshdetection "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -18,9 +19,10 @@ import ( ) const ( - clientStartTimeout = 30 * time.Second - clientStopTimeout = 10 * time.Second - defaultLogLevel = "warn" + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" + defaultSSHDetectionTimeout = 20 * time.Second ) func main() { @@ -125,10 +127,15 @@ func createSSHMethod(client *netbird.Client) js.Func { username = args[2].String() } + var jwtToken string + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + return createPromise(func(resolve, reject js.Value) { sshClient := ssh.NewClient(client) - if err := sshClient.Connect(host, port, username); err != nil { + if err := sshClient.Connect(host, port, username, jwtToken); err != nil { reject.Invoke(err.Error()) return } @@ -191,12 +198,46 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value { })) } +// createDetectSSHServerMethod creates the SSH server detection method +func createDetectSSHServerMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + + timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds()) + if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() { + timeoutMs = args[2].Int() + if timeoutMs <= 0 { + return js.ValueOf("error: timeout must be positive") + } + } + + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) + defer cancel() + + serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port) + if err != nil { + reject.Invoke(err.Error()) + return + } + + resolve.Invoke(js.ValueOf(serverType.RequiresJWT())) + }) + }) +} + // createClientObject wraps the NetBird client in a JavaScript object func createClientObject(client *netbird.Client) js.Value { obj := make(map[string]interface{}) obj["start"] = createStartMethod(client) obj["stop"] = createStopMethod(client) + obj["detectSSHServerType"] = createDetectSSHServerMethod(client) obj["createSSHConnection"] = createSSHMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client) diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 4a23a4bc8..1678c3996 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert } } -func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { - return &tls.Config{ +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config { + config := &tls.Config{ InsecureSkipVerify: true, // We'll validate manually after handshake VerifyConnection: func(cs tls.ConnectionState) error { var certChain [][]byte @@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl return nil }, } + + // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3 + if requiresCredSSP { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS12 + } else { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS13 + } + + return config } diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 8062a05cc..16bf63bb9 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -6,11 +6,13 @@ import ( "context" "crypto/tls" "encoding/asn1" + "errors" "fmt" "io" "net" "sync" "syscall/js" + "time" log "github.com/sirupsen/logrus" ) @@ -19,18 +21,34 @@ const ( RDCleanPathVersion = 3390 RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyScheme = "ws" + + rdpDialTimeout = 15 * time.Second + + GeneralErrorCode = 1 + WSAETimedOut = 10060 + WSAEConnRefused = 10061 + WSAEConnAborted = 10053 + WSAEConnReset = 10054 + WSAEGenericError = 10050 ) type RDCleanPathPDU struct { - Version int64 `asn1:"tag:0,explicit"` - Error []byte `asn1:"tag:1,explicit,optional"` - Destination string `asn1:"utf8,tag:2,explicit,optional"` - ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` - ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` - PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` - X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` - ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` - ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` + Version int64 `asn1:"tag:0,explicit"` + Error RDCleanPathErr `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathErr struct { + ErrorCode int16 `asn1:"tag:0,explicit"` + HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"` + WSALastError int16 `asn1:"tag:2,explicit,optional"` + TLSAlertCode int8 `asn1:"tag:3,explicit,optional"` } type RDCleanPathProxy struct { @@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] destination := conn.destination log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } conn.rdpConn = rdpConn @@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] _, err = rdpConn.Write(firstPacket) if err != nil { log.Errorf("Failed to write first packet: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] n, err := rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { conn.wsHandlers.Call("send", uint8Array.Get("buffer")) } } + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + p.sendToWebSocket(conn, data) +} + +func errorToWSACode(err error) int16 { + if err == nil { + return WSAEGenericError + } + var netErr *net.OpError + if errors.As(err, &netErr) && netErr.Timeout() { + return WSAETimedOut + } + if errors.Is(err, context.DeadlineExceeded) { + return WSAETimedOut + } + if errors.Is(err, context.Canceled) { + return WSAEConnAborted + } + if errors.Is(err, io.EOF) { + return WSAEConnReset + } + return WSAEGenericError +} + +func newWSAError(err error) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + WSALastError: errorToWSACode(err), + }, + } +} + +func newHTTPError(statusCode int16) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + HTTPStatusCode: statusCode, + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go index 010efa5ea..97bb46338 100644 --- a/client/wasm/internal/rdp/rdcleanpath_handlers.go +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -3,6 +3,7 @@ package rdp import ( + "context" "crypto/tls" "encoding/asn1" "io" @@ -11,11 +12,17 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP) + protocolSSL = 0x00000001 + protocolHybridEx = 0x00000008 +) + func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) if pdu.Version != RDCleanPathVersion { - p.sendRDCleanPathError(conn, "Unsupported version") + p.sendRDCleanPathError(conn, newHTTPError(400)) return } @@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl destination = pdu.Destination } - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) - p.sendRDCleanPathError(conn, "Connection failed") + p.sendRDCleanPathError(conn, newWSAError(err)) p.cleanupConnection(conn) return } @@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl p.setupTLSConnection(conn, pdu) } +// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required. +// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags. +// Returns (requiresTLS12, selectedProtocol, detectionSuccessful). +func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) { + const minResponseLength = 19 + + if len(x224Response) < minResponseLength { + return false, 0, false + } + + // Per X.224 specification: + // x224Response[0] == 0x03: Length of X.224 header (3 bytes) + // x224Response[5] == 0xD0: X.224 Data TPDU code + if x224Response[0] != 0x03 || x224Response[5] != 0xD0 { + return false, 0, false + } + + if x224Response[11] == 0x02 { + flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 | + uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24 + + hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0 + return hasNLA, flags, true + } + + return false, 0, false +} + func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { var x224Response []byte if len(pdu.X224ConnectionPDU) > 0 { @@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) if err != nil { log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean n, err := conn.rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") + p.sendRDCleanPathError(conn, newWSAError(err)) return } x224Response = response[:n] log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) } - tlsConfig := p.getTLSConfigWithValidation(conn) + requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response) + if detected { + if requiresCredSSP { + log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol) + } else { + log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol) + } + } else { + log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3") + } + + tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP) tlsConn := tls.Client(conn.rdpConn, tlsConfig) conn.tlsConn = tlsConn if err := tlsConn.Handshake(); err != nil { log.Errorf("TLS handshake failed: %v", err) - p.sendRDCleanPathError(conn, "TLS handshake failed") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean p.cleanupConnection(conn) } -func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { - if len(pdu.X224ConnectionPDU) > 0 { - log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) - _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) - if err != nil { - log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") - return - } - - response := make([]byte, 1024) - n, err := conn.rdpConn.Read(response) - if err != nil { - log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") - return - } - - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - X224ConnectionPDU: response[:n], - ServerAddr: conn.destination, - } - - p.sendRDCleanPathPDU(conn, responsePDU) - } else { - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - ServerAddr: conn.destination, - } - p.sendRDCleanPathPDU(conn, responsePDU) - } - - go p.forwardConnToWS(conn, conn.rdpConn, "TCP") - go p.forwardWSToConn(conn, conn.rdpConn, "TCP") - - <-conn.ctx.Done() - log.Debug("TCP connection context done, cleaning up") - p.cleanupConnection(conn) -} - func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { data, err := asn1.Marshal(pdu) if err != nil { @@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean p.sendToWebSocket(conn, data) } -func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { - pdu := RDCleanPathPDU{ - Version: RDCleanPathVersion, - Error: []byte(errorMsg), - } - - data, err := asn1.Marshal(pdu) - if err != nil { - log.Errorf("Failed to marshal error PDU: %v", err) - return - } - - p.sendToWebSocket(conn, data) -} - func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { msgChan := make(chan []byte) errChan := make(chan error) diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index ca35525eb..568437e56 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/ssh" netbird "github.com/netbirdio/netbird/client/embed" + nbssh "github.com/netbirdio/netbird/client/ssh" ) const ( @@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client { } // Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username string) error { +func (c *Client) Connect(host string, port int, username, jwtToken string) error { addr := fmt.Sprintf("%s:%d", host, port) logrus.Infof("SSH: Connecting to %s as %s", addr, username) - var authMethods []ssh.AuthMethod - - nbConfig, err := c.nbClient.GetConfig() + authMethods, err := c.getAuthMethods(jwtToken) if err != nil { - return fmt.Errorf("get NetBird config: %w", err) + return err } - if nbConfig.SSHKey == "" { - return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") - } - - signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) - if err != nil { - return fmt.Errorf("parse NetBird SSH private key: %w", err) - } - - pubKey := signer.PublicKey() - logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) - - authMethods = append(authMethods, ssh.PublicKeys(signer)) config := &ssh.ClientConfig{ User: username, Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient), Timeout: sshDialTimeout, } @@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error { return nil } +// getAuthMethods returns SSH authentication methods, preferring JWT if available +func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) { + if jwtToken != "" { + logrus.Debugf("SSH: Using JWT password authentication") + return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil + } + + logrus.Debugf("SSH: No JWT token, using public key authentication") + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return nil, fmt.Errorf("get NetBird config: %w", err) + } + + if nbConfig.SSHKey == "" { + return nil, fmt.Errorf("no NetBird SSH key available") + } + + signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return nil, fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + logrus.Debugf("SSH: Added public key auth") + return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil +} + // StartSession starts an SSH session with PTY func (c *Client) StartSession(cols, rows int) error { if c.sshClient == nil { diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go deleted file mode 100644 index 4868ba30a..000000000 --- a/client/wasm/internal/ssh/key.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build js - -package ssh - -import ( - "crypto/x509" - "encoding/pem" - "fmt" - "strings" - - "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" -) - -// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format -func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { - keyStr := string(keyPEM) - if !strings.Contains(keyStr, "-----BEGIN") { - keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") - } - - signer, err := ssh.ParsePrivateKey(keyPEM) - if err == nil { - return signer, nil - } - logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) - - block, _ := pem.Decode(keyPEM) - if block == nil { - keyPreview := string(keyPEM) - if len(keyPreview) > 100 { - keyPreview = keyPreview[:100] - } - return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) - } - - key, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) - if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { - return ssh.NewSignerFromKey(rsaKey) - } - if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { - return ssh.NewSignerFromKey(ecKey) - } - return nil, fmt.Errorf("parse private key: %w", err) - } - - return ssh.NewSignerFromKey(key) -} diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..aa0e16eb1 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. + ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" @@ -31,6 +35,8 @@ type Config struct { NameServerGroups []*NameServerGroup // CustomZones contains a list of custom zone CustomZones []CustomZone + // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding + ForwarderPort uint16 } // CustomZone represents a custom zone to be resolved by the dns server @@ -39,6 +45,10 @@ type CustomZone struct { Domain string // Records custom zone records Records []SimpleRecord + // SearchDomainDisabled indicates whether to add match domains to a search domains list or not + SearchDomainDisabled bool + // SkipPTRProcess indicates whether a client should process PTR records from custom zones + SkipPTRProcess bool } // SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records diff --git a/formatter/hook/hook.go b/formatter/hook/hook.go index c0d8c4eba..f0ee509f8 100644 --- a/formatter/hook/hook.go +++ b/formatter/hook/hook.go @@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error { entry.Data["context"] = source - switch source { - case HTTPSource: - addHTTPFields(entry) - case GRPCSource: - addGRPCFields(entry) - case SystemSource: - addSystemFields(entry) - } + addFields(entry) return nil } @@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string { return fmt.Sprintf("%s/%s", pkg, file) } -func addHTTPFields(entry *logrus.Entry) { +func addFields(entry *logrus.Entry) { if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { entry.Data[context.RequestIDKey] = ctxReqID } @@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) { if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { entry.Data[context.UserIDKey] = ctxInitiatorID } -} - -func addGRPCFields(entry *logrus.Entry) { - if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { - entry.Data[context.RequestIDKey] = ctxReqID - } - if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { - entry.Data[context.AccountIDKey] = ctxAccountID - } - if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { - entry.Data[context.PeerIDKey] = ctxDeviceID - } -} - -func addSystemFields(entry *logrus.Entry) { - if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { - entry.Data[context.RequestIDKey] = ctxReqID - } - if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { - entry.Data[context.UserIDKey] = ctxInitiatorID - } - if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { - entry.Data[context.AccountIDKey] = ctxAccountID - } if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { entry.Data[context.PeerIDKey] = ctxDeviceID } diff --git a/go.mod b/go.mod index a1560b409..8f4ec530b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netbirdio/netbird -go 1.23.0 +go 1.24.10 require ( cunicu.li/go-rosenpass v0.4.0 @@ -16,9 +16,9 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.40.0 - golang.org/x/sys v0.34.0 + github.com/vishvananda/netlink v1.3.1 + golang.org/x/crypto v0.45.0 + golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -28,9 +28,10 @@ require ( ) require ( - fyne.io/fyne/v2 v2.5.3 - fyne.io/systray v1.11.0 + fyne.io/fyne/v2 v2.7.0 + fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible + github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 @@ -43,7 +44,7 @@ require ( github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 - github.com/fsnotify/fsnotify v1.7.0 + github.com/fsnotify/fsnotify v1.9.0 github.com/gliderlabs/ssh v0.3.8 github.com/godbus/dbus/v5 v5.1.0 github.com/golang-jwt/jwt/v5 v5.3.0 @@ -56,13 +57,14 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 + github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -75,14 +77,15 @@ require ( github.com/pion/stun/v3 v3.0.0 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 + github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.22.0 - github.com/quic-go/quic-go v0.48.2 + github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 @@ -98,15 +101,17 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0 + go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 - golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/mod v0.25.0 - golang.org/x/net v0.42.0 - golang.org/x/oauth2 v0.28.0 - golang.org/x/sync v0.16.0 - golang.org/x/term v0.33.0 + golang.org/x/mobile v0.0.0-20251113184115-a159579294ab + golang.org/x/mod v0.30.0 + golang.org/x/net v0.47.0 + golang.org/x/oauth2 v0.30.0 + golang.org/x/sync v0.18.0 + golang.org/x/term v0.37.0 + golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -123,10 +128,11 @@ require ( dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect - github.com/BurntSushi/toml v1.4.0 // indirect + github.com/BurntSushi/toml v1.5.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/awnumar/memcall v0.4.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect @@ -146,7 +152,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/containerd v1.7.27 // indirect + github.com/containerd/containerd v1.7.29 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect @@ -157,11 +163,12 @@ require ( github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fredbi/uri v1.1.0 // indirect - github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect - github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect - github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect - github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect + github.com/fredbi/uri v1.1.1 // indirect + github.com/fyne-io/gl-js v0.2.0 // indirect + github.com/fyne-io/glfw-js v0.3.0 // indirect + github.com/fyne-io/image v0.1.1 // indirect + github.com/fyne-io/oksvg v0.2.0 // indirect + github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -169,7 +176,7 @@ require ( github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-text/render v0.2.0 // indirect - github.com/go-text/typesetting v0.2.0 // indirect + github.com/go-text/typesetting v0.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/btree v1.1.2 // indirect @@ -177,22 +184,23 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect - github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/hack-pad/go-indexeddb v0.3.2 // indirect + github.com/hack-pad/safejs v0.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect + github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect + github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/fs v0.1.0 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -208,7 +216,8 @@ require ( github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect + github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect + github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect @@ -224,29 +233,27 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/rymdport/portal v0.3.0 // indirect + github.com/rymdport/portal v0.4.2 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect - github.com/vishvananda/netns v0.0.4 // indirect + github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.3 // indirect - github.com/yuin/goldmark v1.7.1 // indirect + github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/image v0.18.0 // indirect - golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.34.0 // indirect + golang.org/x/image v0.33.0 // indirect + golang.org/x/text v0.31.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect diff --git a/go.sum b/go.sum index 13838b82d..f10e1e6da 100644 --- a/go.sum +++ b/go.sum @@ -1,67 +1,28 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= -cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= -cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= -cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= -cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= -cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= -cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= -cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI= -cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmWk= -cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= -cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= -cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs= cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= -cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= -cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= -cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= -cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= -cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= -cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= -cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= -cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= -cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= -cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -fyne.io/fyne/v2 v2.5.3 h1:k6LjZx6EzRZhClsuzy6vucLZBstdH2USDGHSGWq8ly8= -fyne.io/fyne/v2 v2.5.3/go.mod h1:0GOXKqyvNwk3DLmsFu9v0oYM0ZcD1ysGnlHCerKoAmo= -fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= -fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ= +fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE= +fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlFYSruNlXD8bRHDiqm0VNI= +fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= -github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0= @@ -70,10 +31,10 @@ github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJT github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= -github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= -github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= -github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= -github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= +github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= +github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= +github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= @@ -114,8 +75,6 @@ github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -138,23 +97,18 @@ github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= -github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= +github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE= +github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= -github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= @@ -182,37 +136,31 @@ github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= -github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/fredbi/uri v1.1.0 h1:OqLpTXtyRg9ABReqvDGdJPqZUxs8cyBDOMXBbskCaB8= -github.com/fredbi/uri v1.1.0/go.mod h1:aYTUoAXBOq7BLfVJ8GnKmfcuURosB1xyHDIfWeC/iW4= +github.com/fredbi/uri v1.1.1 h1:xZHJC08GZNIUhbP5ImTHnt5Ya0T8FI2VAwI/37kh2Ko= +github.com/fredbi/uri v1.1.1/go.mod h1:4+DZQ5zBjEwQCDmXW5JdIjz0PUA+yJbvtBv+u+adr5o= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe h1:A/wiwvQ0CAjPkuJytaD+SsXkPU0asQ+guQEIg1BJGX4= -github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe/go.mod h1:d4clgH0/GrRwWjRzJJQXxT/h1TyuNSfF/X64zb/3Ggg= -github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 h1:/1YRWFv9bAWkoo3SuxpFfzpXH0D/bQnTjNXyF4ih7Os= -github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0/go.mod h1:gsGA2dotD4v0SR6PmPCYvS9JuOeMwAtmfvDE7mbYXMY= -github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 h1:hnLq+55b7Zh7/2IRzWCpiTcAvjv/P8ERF+N7+xXbZhk= -github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2/go.mod h1:eO7W361vmlPOrykIg+Rsh1SZ3tQBaOsfzZhsIOb/Lm0= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs= +github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI= +github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk= +github.com/fyne-io/glfw-js v0.3.0/go.mod h1:Ri6te7rdZtBgBpxLW19uBpp3Dl6K9K/bRaYdJ22G8Jk= +github.com/fyne-io/image v0.1.1 h1:WH0z4H7qfvNUw5l4p3bC1q70sa5+YWVt6HCj7y4VNyA= +github.com/fyne-io/image v0.1.1/go.mod h1:xrfYBh6yspc+KjkgdZU/ifUC9sPA5Iv7WYUBzQKK7JM= +github.com/fyne-io/oksvg v0.2.0 h1:mxcGU2dx6nwjJsSA9PCYZDuoAcsZ/OuJlvg/Q9Njfo8= +github.com/fyne-io/oksvg v0.2.0/go.mod h1:dJ9oEkPiWhnTFNCmRgEze+YNprJF7YRbpjgpWS4kzoI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= -github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 h1:zDw5v7qm4yH7N8C8uWd+8Ii9rROdgWxQuGoJ9WDXxfk= -github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw2H/zl9nrd3eCZfYV+NfQA= +github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -237,11 +185,10 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc= github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU= -github.com/go-text/typesetting v0.2.0 h1:fbzsgbmk04KiWtE+c3ZD4W2nmCRzBqrqQOvYlwAOdho= -github.com/go-text/typesetting v0.2.0/go.mod h1:2+owI/sxa73XA581LAzVuEBZ3WEEV2pXeDswCH/3i1I= -github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66 h1:GUrm65PQPlhFSKjLPGOZNPNxLCybjzjYBzjfoBGaDUY= -github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8= +github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= +github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= +github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -249,27 +196,15 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69 github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -279,25 +214,18 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -308,25 +236,10 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -334,61 +247,34 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA= github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20211219123610-ec9572f70e60/go.mod h1:cz9oNYuRUWGdHmLF2IodMLkAhcPtXeULvcBNagUrxTI= -github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= -github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/goxjs/gl v0.0.0-20210104184919-e3fafc6f8f2a/go.mod h1:dy/f2gjY09hwVfIyATps4G2ai7/hLwLkc5TrPqONuXY= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= -github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= -github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= -github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= -github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hack-pad/go-indexeddb v0.3.2 h1:DTqeJJYc1usa45Q5r52t01KhvlSN02+Oq+tQbSBI91A= +github.com/hack-pad/go-indexeddb v0.3.2/go.mod h1:QvfTevpDVlkfomY498LhstjwbPW6QC4VC/lxYb0Kom0= +github.com/hack-pad/safejs v0.1.0 h1:qPS6vjreAqh2amUqj4WNG1zIw7qlRQJ9K10eDKMCnE8= +github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xNCnb+Tio= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= -github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= -github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= -github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= -github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= -github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -399,8 +285,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 h1:Po+wkNdMmN+Zj1tDsJQy7mJlPlwGNQd9JZoPjObagf8= -github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49/go.mod h1:YiutDnxPRLk5DLUFj6Rw4pRBBURZY07GFr54NdV9mQg= +github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE= +github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -410,12 +296,8 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= -github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk= -github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= +github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= @@ -425,12 +307,10 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -442,14 +322,13 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= +github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= -github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= @@ -461,21 +340,12 @@ github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= -github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= -github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= -github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= -github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= @@ -490,29 +360,26 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= -github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= -github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= -github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ= github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= -github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= -github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= +github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXqaDbq7ju94viiQ= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -534,10 +401,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs= github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= -github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= @@ -569,15 +434,14 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= -github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= +github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw= +github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= @@ -590,51 +454,35 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= +github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= -github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/rymdport/portal v0.3.0 h1:QRHcwKwx3kY5JTQcsVhmhC3TGqGQb9LFghVNUy8AdB8= -github.com/rymdport/portal v0.3.0/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= +github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= -github.com/shurcooL/go v0.0.0-20200502201357-93f07166e636/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= -github.com/shurcooL/httpfs v0.0.0-20190707220628-8d4bc4ba7749/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/vfsgen v0.0.0-20200824052919-0d455de96546/go.mod h1:TrYk7fJVaAttu97ZZKrO9UbRa8izdowaMIZcxYMbVaw= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= -github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= -github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v1.2.1/go.mod h1:ExllRjgxM/piMAM+3tAZvg8fsklGAf3tPfi+i8t68Nk= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= -github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q= github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ= @@ -644,7 +492,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= @@ -656,9 +503,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U= github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI= github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk= @@ -681,25 +527,22 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= -github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= -github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= +github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zcalusic/sysinfo v1.1.3 h1:u/AVENkuoikKuIZ4sUEJ6iibpmQP6YpGD8SSMCrqAF0= @@ -710,16 +553,6 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= -go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= -go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= -go.etcd.io/etcd/client/v2 v2.305.0/go.mod h1:h9puh54ZTgAKtEbut2oe9P4L/oqKCVB6xsXlzd7alYQ= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= @@ -746,211 +579,113 @@ go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= -golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= +golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= +golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= -golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mobile v0.0.0-20211207041440-4e6c2922fdee/go.mod h1:pe2sM7Uk+2Su1y7u/6Z8KJ24D7lepUjFZbhFOrmDfuQ= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q= +golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= -golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -962,98 +697,60 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= -golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= -golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= -golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.8-0.20211022200916-316ba0b74098/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= -golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1064,103 +761,24 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= -google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= -google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE= -google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= -google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= -google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= -google.golang.org/api v0.44.0/go.mod h1:EBOGZqzyhtvMDoxwS97ctnh0zUmYY6CxqXsc1AvkYD8= google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk= google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= -google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= -google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM= google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8= google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= -google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -1171,7 +789,6 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= @@ -1180,14 +797,11 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= -gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= @@ -1197,14 +811,12 @@ gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= @@ -1221,12 +833,4 @@ gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..1c9c63f78 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,22 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console", + "--port", "80" + ] # Relay relay: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 196e26a66..0010974c5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -49,6 +49,7 @@ services: - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.routers.netbird-signal.service=netbird-signal - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", diff --git a/management/internals/controllers/network_map/controller/cache/dns_config_cache.go b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go new file mode 100644 index 000000000..8cc634ef4 --- /dev/null +++ b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go @@ -0,0 +1,31 @@ +package cache + +import ( + "sync" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// DNSConfigCache is a thread-safe cache for DNS configuration components +type DNSConfigCache struct { + NameServerGroups sync.Map +} + +// GetNameServerGroup retrieves a cached name server group +func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { + if c == nil { + return nil, false + } + if value, ok := c.NameServerGroups.Load(key); ok { + return value.(*proto.NameServerGroup), true + } + return nil, false +} + +// SetNameServerGroup stores a name server group in the cache +func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { + if c == nil { + return + } + c.NameServerGroups.Store(key, value) +} diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go new file mode 100644 index 000000000..df16e1922 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller.go @@ -0,0 +1,829 @@ +package controller + +import ( + "context" + "errors" + "fmt" + "os" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "golang.org/x/mod/semver" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util" +) + +type Controller struct { + repo Repository + metrics *metrics + // This should not be here, but we need to maintain it for the time being + accountManagerMetrics *telemetry.AccountManagerMetrics + peersUpdateManager network_map.PeersUpdateManager + settingsManager settings.Manager + EphemeralPeersManager ephemeral.Manager + + accountUpdateLocks sync.Map + sendAccountUpdateLocks sync.Map + updateAccountPeersBufferInterval atomic.Int64 + // dnsDomain is used for peer resolution. This is appended to the peer's name + dnsDomain string + config *config.Config + + requestBuffer account.RequestBuffer + + proxyController port_forwarding.Controller + + integratedPeerValidator integrated_validator.IntegratedValidator + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} +} + +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} + +var _ network_map.Controller = (*Controller)(nil) + +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller { + nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) + if err != nil { + log.Fatal(fmt.Errorf("error creating metrics: %w", err)) + } + + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + + return &Controller{ + repo: newRepository(store), + metrics: nMetrics, + accountManagerMetrics: metrics.AccountManagerMetrics(), + peersUpdateManager: peersUpdateManager, + requestBuffer: requestBuffer, + integratedPeerValidator: integratedPeerValidator, + settingsManager: settingsManager, + dnsDomain: dnsDomain, + config: config, + + proxyController: proxyController, + EphemeralPeersManager: ephemeralPeersManager, + + holder: types.NewHolder(), + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, + } +} + +func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) { + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err) + } + + c.EphemeralPeersManager.OnPeerConnected(ctx, peer) + + return c.peersUpdateManager.CreateChannel(ctx, peerID), nil +} + +func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) { + c.peersUpdateManager.CloseChannel(ctx, peerID) + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err) + return + } + c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer) +} + +func (c *Controller) CountStreams() int { + return c.peersUpdateManager.CountStreams() +} + +func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) + } + } + + globalStart := time.Now() + + hasPeersConnected := false + for _, peer := range account.Peers { + if c.peersUpdateManager.HasChannel(peer.ID) { + hasPeersConnected = true + break + } + + } + + if !hasPeersConnected { + return nil + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validate peers: %v", err) + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + if c.experimentalNetworkMap(accountID) { + c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return fmt.Errorf("failed to get proxy network maps: %v", err) + } + + extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get flow enabled status: %v", err) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + for _, peer := range account.Peers { + if !c.peersUpdateManager.HasChannel(peer.ID) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + continue + } + + wg.Add(1) + semaphore <- struct{}{} + go func(p *nbpeer.Peer) { + defer wg.Done() + defer func() { <-semaphore }() + + start := time.Now() + + postureChecks, err := c.getPeerPostureChecks(account, p.ID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err) + return + } + + c.metrics.CountCalcPostureChecksDuration(time.Since(start)) + start = time.Now() + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + } + + c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + peerGroups := account.GetPeerGroups(p.ID) + start = time.Now() + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + c.metrics.CountToSyncResponseDuration(time.Since(start)) + + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + }(peer) + } + + wg.Wait() + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart)) + } + + return nil +} + +func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.sendUpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.sendUpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +// UpdatePeers updates all peers that belong to an account. +// Should be called when changes have to be synced to peers. +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { + if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return fmt.Errorf("recalculate network map cache: %v", err) + } + + return c.sendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { + if !c.peersUpdateManager.HasChannel(peerId) { + return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) + } + + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err) + } + + peer := account.GetPeer(peerId) + if peer == nil { + return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId) + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validated peers: %v", err) + } + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + postureChecks, err := c.getPeerPostureChecks(account, peerId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) + return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return err + } + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get extra settings: %v", err) + } + + peerGroups := account.GetPeerGroups(peerId) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + + return nil +} + +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + if isRequiresApproval { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + + emptyMap := &types.NetworkMap{ + Network: network.Copy(), + } + return peer, emptyMap, nil, 0, nil + } + + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, nil, nil, 0, err + } + + startPosture := time.Now() + postureChecks, err := c.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, 0, err + } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) + + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, nil, nil, 0, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers()) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + return peer, networkMap, postureChecks, dnsFwdPort, nil +} + +func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + c.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (c *Controller) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := c.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := c.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + c.updateAccountInHolder(account) +} + +func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if c.experimentalNetworkMap(accountId) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + c.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (c *Controller) experimentalNetworkMap(accountId string) bool { + _, ok := c.expNewNetworkMapAIDs[accountId] + return c.expNewNetworkMap || ok +} + +func (c *Controller) enrichAccountFromHolder(account *types.Account) { + a := c.holder.GetAccount(account.Id) + if a == nil { + c.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + c.holder.AddAccount(account) +} + +func (c *Controller) getAccountFromHolder(accountID string) *types.Account { + return c.holder.GetAccount(accountID) +} + +func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account { + a := c.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (c *Controller) updateAccountInHolder(account *types.Account) { + c.holder.AddAccount(account) +} + +// GetDNSDomain returns the configured dnsDomain +func (c *Controller) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return c.dnsDomain + } + if settings.DNSDomain == "" { + return c.dnsDomain + } + + return settings.DNSDomain +} + +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) + + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err + } + } + + return maps.Values(peerPostureChecks), nil +} + +func (c *Controller) StartWarmup(ctx context.Context) { + var initialInterval int64 + intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") + interval, err := strconv.Atoi(intervalStr) + if err != nil { + initialInterval = 1 + log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) + } else { + initialInterval = int64(interval) * 10 + go func() { + startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") + startupPeriod, err := strconv.Atoi(startupPeriodStr) + if err != nil { + startupPeriod = 1 + log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) + } + time.Sleep(time.Duration(startupPeriod) * time.Second) + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) + }() + } + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) + +} + +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return int64(network_map.OldForwarderPort) + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return int64(network_map.OldForwarderPort) + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return int64(network_map.OldForwarderPort) + } + } + + // All peers have the required version or newer + return int64(network_map.DnsForwarderPort) +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) + if err != nil { + return err + } + + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck := account.GetPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil +} + +// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, sourceGroup := range rule.Sources { + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") + } + + if slices.Contains(group.Peers, peerID) { + return true, nil + } + } + } + + return false, nil +} + +func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { + peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("failed to get peers by ids: %w", err) + } + + for _, peer := range peers { + c.UpdatePeerInNetworkMapCache(accountID, peer) + } + + err = c.bufferSendUpdateAccountPeers(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) + } + + return nil +} + +func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + for _, peerID := range peerIDs { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + } + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountID) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + for _, peerID := range peerIDs { + c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerID) + + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) + continue + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) + continue + } + } + } + + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) +func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + account, err := c.repo.GetAccountByPeerID(ctx, peerID) + if err != nil { + return nil, err + } + + peer := account.GetPeer(peerID) + if peer == nil { + return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) + } + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, err + } + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(peer.AccountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + return networkMap, nil +} + +func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { + c.peersUpdateManager.CloseChannels(ctx, peerIDs) +} diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go new file mode 100644 index 000000000..90e7b6e18 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller_test.go @@ -0,0 +1,109 @@ +package controller + +import ( + "testing" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.DnsForwarderPort) { + t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result) + } +} diff --git a/management/internals/controllers/network_map/controller/metrics.go b/management/internals/controllers/network_map/controller/metrics.go new file mode 100644 index 000000000..5832d2130 --- /dev/null +++ b/management/internals/controllers/network_map/controller/metrics.go @@ -0,0 +1,15 @@ +package controller + +import ( + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type metrics struct { + *telemetry.UpdateChannelMetrics +} + +func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) { + return &metrics{ + updateChannelMetrics, + }, nil +} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go new file mode 100644 index 000000000..3ed51a5c3 --- /dev/null +++ b/management/internals/controllers/network_map/controller/repository.go @@ -0,0 +1,49 @@ +package controller + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Repository interface { + GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) + GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) + GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) +} + +type repository struct { + store store.Store +} + +var _ Repository = (*repository)(nil) + +func newRepository(s store.Store) Repository { + return &repository{ + store: s, + } +} + +func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) { + return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) +} + +func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) { + return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { + return r.store.GetAccountByPeerID(ctx, peerID) +} + +func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) { + return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs) +} + +func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) { + return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go new file mode 100644 index 000000000..b1de7d017 --- /dev/null +++ b/management/internals/controllers/network_map/interface.go @@ -0,0 +1,39 @@ +package network_map + +//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" + + DnsForwarderPort = nbdns.ForwarderServerPort + OldForwarderPort = nbdns.ForwarderClientPort + DnsForwarderPortMinVersion = "v0.59.0" +) + +type Controller interface { + UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string) error + GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + GetDNSDomain(settings *types.Settings) string + StartWarmup(context.Context) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + CountStreams() int + + OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error + OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error + OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error + DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) + OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error) + OnPeerDisconnected(ctx context.Context, accountID string, peerID string) +} diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go new file mode 100644 index 000000000..5a98eefa8 --- /dev/null +++ b/management/internals/controllers/network_map/interface_mock.go @@ -0,0 +1,240 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +// + +// Package network_map is a generated GoMock package. +package network_map + +import ( + context "context" + reflect "reflect" + + peer "github.com/netbirdio/netbird/management/server/peer" + posture "github.com/netbirdio/netbird/management/server/posture" + types "github.com/netbirdio/netbird/management/server/types" + gomock "go.uber.org/mock/gomock" +) + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder + isgomock struct{} +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// BufferUpdateAccountPeers mocks base method. +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) +} + +// CountStreams mocks base method. +func (m *MockController) CountStreams() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountStreams") + ret0, _ := ret[0].(int) + return ret0 +} + +// CountStreams indicates an expected call of CountStreams. +func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams)) +} + +// DisconnectPeers mocks base method. +func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs) +} + +// DisconnectPeers indicates an expected call of DisconnectPeers. +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs) +} + +// GetDNSDomain mocks base method. +func (m *MockController) GetDNSDomain(settings *types.Settings) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDNSDomain", settings) + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDNSDomain indicates an expected call of GetDNSDomain. +func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings) +} + +// GetNetworkMap mocks base method. +func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID) + ret0, _ := ret[0].(*types.NetworkMap) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNetworkMap indicates an expected call of GetNetworkMap. +func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID) +} + +// GetValidatedPeerWithMap mocks base method. +func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(int64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 +} + +// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap. +func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) +} + +// OnPeerConnected mocks base method. +func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID) + ret0, _ := ret[0].(chan *UpdateMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OnPeerConnected indicates an expected call of OnPeerConnected. +func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID) +} + +// OnPeerDisconnected mocks base method. +func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID) +} + +// OnPeerDisconnected indicates an expected call of OnPeerDisconnected. +func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID) +} + +// OnPeersAdded mocks base method. +func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeersAdded indicates an expected call of OnPeersAdded. +func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs) +} + +// OnPeersDeleted mocks base method. +func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeersDeleted indicates an expected call of OnPeersDeleted. +func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs) +} + +// OnPeersUpdated mocks base method. +func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeersUpdated indicates an expected call of OnPeersUpdated. +func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs) +} + +// StartWarmup mocks base method. +func (m *MockController) StartWarmup(arg0 context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartWarmup", arg0) +} + +// StartWarmup indicates an expected call of StartWarmup. +func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0) +} + +// UpdateAccountPeer mocks base method. +func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeer indicates an expected call of UpdateAccountPeer. +func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId) +} + +// UpdateAccountPeers mocks base method. +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeers indicates an expected call of UpdateAccountPeers. +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) +} diff --git a/management/internals/controllers/network_map/network_map.go b/management/internals/controllers/network_map/network_map.go new file mode 100644 index 000000000..e915c2193 --- /dev/null +++ b/management/internals/controllers/network_map/network_map.go @@ -0,0 +1 @@ +package network_map diff --git a/management/internals/controllers/network_map/update_channel.go b/management/internals/controllers/network_map/update_channel.go new file mode 100644 index 000000000..0b085b85f --- /dev/null +++ b/management/internals/controllers/network_map/update_channel.go @@ -0,0 +1,13 @@ +package network_map + +import "context" + +type PeersUpdateManager interface { + SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) + CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage + CloseChannel(ctx context.Context, peerID string) + CountStreams() int + HasChannel(peerID string) bool + CloseChannels(ctx context.Context, peerIDs []string) + GetAllConnectedPeers() map[string]struct{} +} diff --git a/management/server/updatechannel.go b/management/internals/controllers/network_map/update_channel/updatechannel.go similarity index 87% rename from management/server/updatechannel.go rename to management/internals/controllers/network_map/update_channel/updatechannel.go index da12f1b70..5f7db5300 100644 --- a/management/server/updatechannel.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel.go @@ -1,4 +1,4 @@ -package server +package update_channel import ( "context" @@ -7,38 +7,34 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" ) const channelBufferSize = 100 -type UpdateMessage struct { - Update *proto.SyncResponse - NetworkMap *types.NetworkMap -} - type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID - peerChannels map[string]chan *UpdateMessage + peerChannels map[string]chan *network_map.UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } +var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil) + // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), + peerChannels: make(map[string]chan *network_map.UpdateMessage), channelsMux: &sync.RWMutex{}, metrics: metrics, } } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { +func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) { start := time.Now() var found, dropped bool @@ -66,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage { start := time.Now() closed := false @@ -85,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c close(channel) } // mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateMessage, channelBufferSize) + channel := make(chan *network_map.UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) @@ -176,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +func (p *PeersUpdateManager) CountStreams() int { + p.channelsMux.RLock() + defer p.channelsMux.RUnlock() + return len(p.peerChannels) +} diff --git a/management/server/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go similarity index 89% rename from management/server/updatechannel_test.go rename to management/internals/controllers/network_map/update_channel/updatechannel_test.go index 0dc86563d..afc1e2c32 100644 --- a/management/server/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -1,10 +1,11 @@ -package server +package update_channel import ( "context" "testing" "time" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &UpdateMessage{Update: &proto.SyncResponse{ + update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 0, }, @@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &UpdateMessage{Update: &proto.SyncResponse{ + update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 10, }, diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go new file mode 100644 index 000000000..33643bcbd --- /dev/null +++ b/management/internals/controllers/network_map/update_message.go @@ -0,0 +1,9 @@ +package network_map + +import ( + "github.com/netbirdio/netbird/shared/management/proto" +) + +type UpdateMessage struct { + Update *proto.SyncResponse +} diff --git a/management/server/peers/ephemeral/interface.go b/management/internals/modules/peers/ephemeral/interface.go similarity index 83% rename from management/server/peers/ephemeral/interface.go rename to management/internals/modules/peers/ephemeral/interface.go index a1605b3b9..8fe25435c 100644 --- a/management/server/peers/ephemeral/interface.go +++ b/management/internals/modules/peers/ephemeral/interface.go @@ -2,10 +2,15 @@ package ephemeral import ( "context" + "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) +const ( + EphemeralLifeTime = 10 * time.Minute +) + type Manager interface { LoadInitialPeers(ctx context.Context) Stop() diff --git a/management/server/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go similarity index 85% rename from management/server/peers/ephemeral/manager/ephemeral.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral.go index 062ba69d2..15119045b 100644 --- a/management/server/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -7,14 +7,15 @@ import ( log "github.com/sirupsen/logrus" - nbAccount "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" ) const ( - ephemeralLifeTime = 10 * time.Minute // cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure. cleanupWindow = 1 * time.Minute ) @@ -33,11 +34,11 @@ type ephemeralPeer struct { // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // in worst case we will get invalid error message in this manager. -// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted +// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store store.Store - accountManager nbAccount.Manager + store store.Store + peersManager peers.Manager headPeer *ephemeralPeer tailPeer *ephemeralPeer @@ -49,12 +50,12 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager { +func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager { return &EphemeralManager{ - store: store, - accountManager: accountManager, + store: store, + peersManager: peersManager, - lifeTime: ephemeralLifeTime, + lifeTime: ephemeral.EphemeralLifeTime, cleanupWindow: cleanupWindow, } } @@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee } // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer -// is inactive it will be deleted after the ephemeralLifeTime period. +// is inactive it will be deleted after the EphemeralLifeTime period. func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return @@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() - bufferAccountCall := make(map[string]struct{}) - + peerIDsPerAccount := make(map[string][]string) for id, p := range deletePeers { - log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) + peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id) + } + + for accountID, peerIDs := range peerIDsPerAccount { + log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) + err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) - } else { - bufferAccountCall[p.accountID] = struct{}{} } } - for accountID := range bufferAccountCall { - e.accountManager.BufferUpdateAccountPeers(ctx, accountID) - } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { diff --git a/management/server/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go similarity index 69% rename from management/server/peers/ephemeral/manager/ephemeral_test.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc7525c29..9d3ed246a 100644 --- a/management/server/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -7,10 +7,13 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) if len(store.account.Peers) != numberOfPeers { @@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers (except the connected one) + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + 1 @@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for the one disconnected peer + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { mgr.OnPeerConnected(context.Background(), v) @@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + numberOfEphemeralPeers - 1 @@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { testLifeTime = 1 * time.Second testCleanupWindow = 100 * time.Millisecond ) + + t.Cleanup(func() { + timeNow = time.Now + }) + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + mockStore := &MockStore{} + account := newAccountWithId(context.Background(), "account", "", "", false) + mockStore.account = account + + wg := &sync.WaitGroup{} + wg.Add(ephemeralPeers) mockAM := &MockAccountManager{ store: mockStore, + wg: wg, } - mockAM.wg = &sync.WaitGroup{} - mockAM.wg.Add(ephemeralPeers) - mgr := NewEphemeralManager(mockStore, mockAM) + + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) + + // Set up expectation that DeletePeers will be called once with all peer IDs + peersManager.EXPECT(). + DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + // Simulate the actual deletion behavior + for _, peerID := range peerIDs { + err := mockAM.DeletePeer(ctx, accountID, peerID, userID) + if err != nil { + return err + } + } + mockAM.BufferUpdateAccountPeers(ctx, accountID) + return nil + }). + Times(1) + + mgr := NewEphemeralManager(mockStore, peersManager) mgr.lifeTime = testLifeTime mgr.cleanupWindow = testCleanupWindow - account := newAccountWithId(context.Background(), "account", "", "", false) - mockStore.account = account + // Add peers and disconnect them at slightly different times (within cleanup window) for i := range ephemeralPeers { p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true} mockStore.account.Peers[p.ID] = p - time.Sleep(testCleanupWindow / ephemeralPeers) mgr.OnPeerDisconnected(context.Background(), p) + startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2)) } - mockAM.wg.Wait() + + // Advance time past the lifetime to trigger cleanup + startTime = startTime.Add(testLifeTime + testCleanupWindow) + + // Wait for all deletions to complete + wg.Wait() + assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime") assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once") assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers") diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go new file mode 100644 index 000000000..b200b9663 --- /dev/null +++ b/management/internals/modules/peers/manager.go @@ -0,0 +1,164 @@ +package peers + +//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + +import ( + "context" + "fmt" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +type Manager interface { + GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) + GetPeerAccountID(ctx context.Context, peerID string) (string, error) + GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) + DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error + SetNetworkMapController(networkMapController network_map.Controller) + SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) + SetAccountManager(accountManager account.Manager) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + integratedPeerValidator integrated_validator.IntegratedValidator + accountManager account.Manager + + networkMapController network_map.Controller +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) { + m.networkMapController = networkMapController +} + +func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.integratedPeerValidator = integratedPeerValidator +} + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + m.accountManager = accountManager +} + +func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} + +func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) + } + + return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { + return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) +} + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} + +func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + dnsDomain := m.networkMapController.GetDNSDomain(settings) + + for _, peerID := range peerIDs { + var eventsToStore []func() + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + + if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { + return nil + } + + if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil { + return fmt.Errorf("failed to remove peer %s from groups", peerID) + } + + if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil { + return err + } + + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + + if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + + return nil + }) + if err != nil { + return err + } + for _, event := range eventsToStore { + event() + } + } + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} diff --git a/management/server/peers/manager_mock.go b/management/internals/modules/peers/manager_mock.go similarity index 55% rename from management/server/peers/manager_mock.go rename to management/internals/modules/peers/manager_mock.go index 994f8346b..2e3651e88 100644 --- a/management/server/peers/manager_mock.go +++ b/management/internals/modules/peers/manager_mock.go @@ -9,6 +9,9 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map" + account "github.com/netbirdio/netbird/management/server/account" + integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" peer "github.com/netbirdio/netbird/management/server/peer" ) @@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// DeletePeers mocks base method. +func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeers indicates an expected call of DeletePeers. +func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected) +} + // GetAllPeers mocks base method. func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { m.ctrl.T.Helper() @@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) } + +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + +// SetIntegratedPeerValidator mocks base method. +func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator) +} + +// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator. +func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator) +} + +// SetNetworkMapController mocks base method. +func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNetworkMapController", networkMapController) +} + +// SetNetworkMapController indicates an expected call of SetNetworkMapController. +func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 16e93a549..57b3fac78 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -10,9 +10,9 @@ import ( "slices" "time" - "github.com/google/uuid" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -22,7 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" @@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { - store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false) + store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) if err != nil { log.Fatalf("failed to create store: %v", err) } @@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store { log.Fatalf("failed to initialize integration metrics: %v", err) } - eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics) + eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) if err != nil { log.Fatalf("failed to initialize event store: %v", err) } - if s.config.DataStoreEncryptionKey != key { - log.WithContext(context.Background()).Infof("update config with activity store key") - s.config.DataStoreEncryptionKey = key - err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config) + if s.Config.DataStoreEncryptionKey != key { + log.WithContext(context.Background()).Infof("update Config with activity store key") + s.Config.DataStoreEncryptionKey = key + err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config) if err != nil { - log.Fatalf("failed to update config with activity store: %v", err) + log.Fatalf("failed to update Config with activity store: %v", err) } } @@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { - trustedPeers := s.config.ReverseProxy.TrustedPeers + trustedPeers := s.Config.ReverseProxy.TrustedPeers defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) { log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") trustedPeers = defaultTrustedPeers } - trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies - trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount + trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies + trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 { log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") @@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server { grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), } - if s.config.HttpConfig.LetsEncryptDomain != "" { - certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { log.Fatalf("failed to create certificate manager: %v", err) } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.Fatalf("cannot load TLS credentials: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } @@ -180,7 +180,7 @@ func unaryInterceptor( info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { - reqID := uuid.New().String() + reqID := xid.New().String() //nolint ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource) //nolint @@ -194,7 +194,7 @@ func streamInterceptor( info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - reqID := uuid.New().String() + reqID := xid.New().String() wrapped := grpcMiddleware.WrapServerStream(ss) //nolint ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource) diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index ddd81daa2..3442c7646 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,17 +6,21 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) -func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { - return Create(s, func() *server.PeersUpdateManager { - return server.NewPeersUpdateManager(s.Metrics()) +func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { + return Create(s, func() network_map.PeersUpdateManager { + return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -40,26 +44,46 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager { - return Create(s, func() *server.TimeBasedAuthSecretsManager { - return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() grpc.SecretsManager { + return Create(s, func() grpc.SecretsManager { + secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager()) + if err != nil { + log.Fatalf("failed to create secrets manager: %v", err) + } + return secretsManager }) } func (s *BaseServer) AuthManager() auth.Manager { return Create(s, func() auth.Manager { return auth.NewManager(s.Store(), - s.config.HttpConfig.AuthIssuer, - s.config.HttpConfig.AuthAudience, - s.config.HttpConfig.AuthKeysLocation, - s.config.HttpConfig.AuthUserIDClaim, - s.config.GetAuthAudiences(), - s.config.HttpConfig.IdpSignKeyRefreshEnabled) + s.Config.HttpConfig.AuthIssuer, + s.Config.HttpConfig.AuthAudience, + s.Config.HttpConfig.AuthKeysLocation, + s.Config.HttpConfig.AuthUserIDClaim, + s.Config.GetAuthAudiences(), + s.Config.HttpConfig.IdpSignKeyRefreshEnabled) }) } func (s *BaseServer) EphemeralManager() ephemeral.Manager { return Create(s, func() ephemeral.Manager { - return manager.NewEphemeralManager(s.Store(), s.AccountManager()) + return manager.NewEphemeralManager(s.Store(), s.PeersManager()) }) } + +func (s *BaseServer) NetworkMapController() network_map.Controller { + return Create(s, func() network_map.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config) + }) +} + +func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { + return Create(s, func() *server.AccountRequestBuffer { + return server.NewAccountRequestBuffer(context.Background(), s.Store()) + }) +} + +func (s *BaseServer) DNSDomain() string { + return s.dnsDomain +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..af9ca5f2d 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -2,10 +2,12 @@ package server import ( "context" + "os" log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/geolocation" @@ -14,20 +16,29 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/users" ) +const ( + geolocationDisabledKey = "NB_DISABLE_GEOLOCATION" +) + func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { + if os.Getenv(geolocationDisabledKey) == "true" { + log.Info("geolocation service is disabled, skipping initialization") + return nil + } + return Create(s, func() geolocation.Geolocation { - geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate) + geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate) if err != nil { log.Fatalf("could not initialize geolocation service: %v", err) } - log.Infof("geolocation service has been initialized from %s", s.config.Datadir) + log.Infof("geolocation service has been initialized from %s", s.Config.Datadir) return geo }) @@ -35,7 +46,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } @@ -54,21 +71,22 @@ func (s *BaseServer) SettingsManager() settings.Manager { func (s *BaseServer) PeersManager() peers.Manager { return Create(s, func() peers.Manager { - return peers.NewManager(s.Store(), s.PermissionsManager()) + manager := peers.NewManager(s.Store(), s.PermissionsManager()) + s.AfterInit(func(s *BaseServer) { + manager.SetNetworkMapController(s.NetworkMapController()) + manager.SetIntegratedPeerValidator(s.IntegratedValidator()) + manager.SetAccountManager(s.AccountManager()) + }) + return manager }) } func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, - s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } - - s.AfterInit(func(s *BaseServer) { - accountManager.SetEphemeralManager(s.EphemeralManager()) - }) return accountManager }) } @@ -77,8 +95,8 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error - if s.config.IdpManagerConfig != nil { - idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics()) + if s.Config.IdpManagerConfig != nil { + idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { log.Fatalf("failed to create IDP manager: %v", err) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ab1c2ebe7..d9c715225 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -41,10 +41,10 @@ type Server interface { } // Server holds the HTTP BaseServer instance. -// Add any additional fields you need, such as database connections, config, etc. +// Add any additional fields you need, such as database connections, Config, etc. type BaseServer struct { - // config holds the server configuration - config *nbconfig.Config + // Config holds the server configuration + Config *nbconfig.Config // container of dependencies, each dependency is identified by a unique string. container map[string]any // AfterInit is a function that will be called after the server is initialized @@ -70,7 +70,7 @@ type BaseServer struct { // NewServer initializes and configures a new Server instance func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer { return &BaseServer{ - config: config, + Config: config, container: make(map[string]any), dnsDomain: dnsDomain, mgmtSingleAccModeDomain: mgmtSingleAccModeDomain, @@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error { var tlsConfig *tls.Config tlsEnabled := false - if s.config.HttpConfig.LetsEncryptDomain != "" { - s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) } tlsEnabled = true - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err) return err @@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error { if !s.disableMetrics { idpManager := "disabled" - if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" { - idpManager = s.config.IdpManagerConfig.ManagerType + if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" { + idpManager = s.Config.IdpManagerConfig.ManagerType } metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager) go metricsWorker.Run(srvCtx) @@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) - s.update = version.NewUpdate("nb/management") + s.update = version.NewUpdateAndStart("nb/management") s.update.SetDaemonVersion(version.NetbirdVersion()) s.update.SetOnUpdateListener(func() { log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion()) diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go new file mode 100644 index 000000000..f984c73df --- /dev/null +++ b/management/internals/shared/grpc/conversion.go @@ -0,0 +1,450 @@ +package grpc + +import ( + "context" + "fmt" + "net/url" + "strings" + + log "github.com/sirupsen/logrus" + + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/client/ssh/auth" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/sshauth" +) + +func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { + if config == nil { + return nil + } + + var stuns []*proto.HostConfig + for _, stun := range config.Stuns { + stuns = append(stuns, &proto.HostConfig{ + Uri: stun.URI, + Protocol: ToResponseProto(stun.Proto), + }) + } + + var turns []*proto.ProtectedHostConfig + if config.TURNConfig != nil { + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Payload + password = turnCredentials.Signature + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + } + + var relayCfg *proto.RelayConfig + if config.Relay != nil && len(config.Relay.Addresses) > 0 { + relayCfg = &proto.RelayConfig{ + Urls: config.Relay.Addresses, + } + + if relayToken != nil { + relayCfg.TokenPayload = relayToken.Payload + relayCfg.TokenSignature = relayToken.Signature + } + } + + var signalCfg *proto.HostConfig + if config.Signal != nil { + signalCfg = &proto.HostConfig{ + Uri: config.Signal.URI, + Protocol: ToResponseProto(config.Signal.Proto), + } + } + + nbConfig := &proto.NetbirdConfig{ + Stuns: stuns, + Turns: turns, + Signal: signalCfg, + Relay: relayCfg, + } + + return nbConfig +} + +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() + fqdn := peer.FQDN(dnsName) + + sshConfig := &proto.SSHConfig{ + SshEnabled: peer.SSHEnabled || enableSSH, + } + + if sshConfig.SshEnabled { + sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) + } + + return &proto.PeerConfig{ + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), + SshConfig: sshConfig, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, + LazyConnectionEnabled: settings.LazyConnectionEnabled, + AutoUpdate: &proto.AutoUpdateSettings{ + Version: settings.AutoUpdateVersion, + }, + } +} + +func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + response := &proto.SyncResponse{ + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), + NetworkMap: &proto.NetworkMap{ + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), + }, + Checks: toProtocolChecks(ctx, checks), + } + + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) + response.NetbirdConfig = extendedConfig + + response.NetworkMap.PeerConfig = response.PeerConfig + + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 + response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty + + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + response.NetworkMap.FirewallRules = firewallRules + response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + + if networkMap.ForwardingRules != nil { + forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) + for _, rule := range networkMap.ForwardingRules { + forwardingRules = append(forwardingRules, rule.ToProto()) + } + response.NetworkMap.ForwardingRules = forwardingRules + } + + if networkMap.AuthorizedUsers != nil { + hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers) + userIDClaim := auth.DefaultUserIDClaim + if httpConfig != nil && httpConfig.AuthUserIDClaim != "" { + userIDClaim = httpConfig.AuthUserIDClaim + } + response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim} + } + + return response +} + +func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) { + userIDToIndex := make(map[string]uint32) + var hashedUsers [][]byte + machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers)) + + for machineUser, users := range authorizedUsers { + indexes := make([]uint32, 0, len(users)) + for userID := range users { + idx, exists := userIDToIndex[userID] + if !exists { + hash, err := sshauth.HashUserID(userID) + if err != nil { + log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err) + continue + } + idx = uint32(len(hashedUsers)) + userIDToIndex[userID] = idx + hashedUsers = append(hashedUsers, hash[:]) + } + indexes = append(indexes, idx) + } + machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes} + } + + return hashedUsers, machineUsers +} + +func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + for _, rPeer := range peers { + dst = append(dst, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{rPeer.IP.String() + "/32"}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: rPeer.FQDN(dnsName), + AgentVersion: rPeer.Meta.WtVersion, + }) + } + return dst +} + +// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache +func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig { + protoUpdate := &proto.DNSConfig{ + ServiceEnable: update.ServiceEnable, + CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), + NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, + } + + for _, zone := range update.CustomZones { + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) + } + + for _, nsGroup := range update.NameServerGroups { + cacheKey := nsGroup.ID + if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) + } else { + protoGroup := convertToProtoNameServerGroup(nsGroup) + cache.SetNameServerGroup(cacheKey, protoGroup) + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) + } + } + + return protoUpdate +} + +func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { + switch configProto { + case nbconfig.UDP: + return proto.HostConfig_UDP + case nbconfig.DTLS: + return proto.HostConfig_DTLS + case nbconfig.HTTP: + return proto.HostConfig_HTTP + case nbconfig.HTTPS: + return proto.HostConfig_HTTPS + case nbconfig.TCP: + return proto.HostConfig_TCP + default: + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) + } +} + +func toProtocolRoutes(routes []*route.Route) []*proto.Route { + protoRoutes := make([]*proto.Route, 0, len(routes)) + for _, r := range routes { + protoRoutes = append(protoRoutes, toProtocolRoute(r)) + } + return protoRoutes +} + +func toProtocolRoute(route *route.Route) *proto.Route { + return &proto.Route{ + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, + SkipAutoApply: route.SkipAutoApply, + } +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + fwRule := &proto.FirewallRule{ + PolicyID: []byte(rule.PolicyID), + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, + } + + if shouldUsePortRange(fwRule) { + fwRule.PortInfo = rule.PortRange.ToProto() + } + + result[i] = fwRule + } + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == types.FirewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), + PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), + } + } + + return result +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(types.PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case types.PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case types.PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case types.PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} + +func shouldUsePortRange(rule *proto.FirewallRule) bool { + return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) +} + +// Helper function to convert nbdns.CustomZone to proto.CustomZone +func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { + protoZone := &proto.CustomZone{ + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + } + for _, record := range zone.Records { + protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ + Name: record.Name, + Type: int64(record.Type), + Class: record.Class, + TTL: int64(record.TTL), + RData: record.RData, + }) + } + return protoZone +} + +// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup +func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { + protoGroup := &proto.NameServerGroup{ + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), + } + for _, ns := range nsGroup.NameServers { + protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ + IP: ns.IP.String(), + Port: int64(ns.Port), + NSType: int64(ns.NSType), + }) + } + return protoGroup +} + +// buildJWTConfig constructs JWT configuration for SSH servers from management server config +func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig { + if config == nil || config.AuthAudience == "" { + return nil + } + + issuer := strings.TrimSpace(config.AuthIssuer) + if issuer == "" && deviceFlowConfig != nil { + if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" { + issuer = d + } + } + if issuer == "" { + return nil + } + + keysLocation := strings.TrimSpace(config.AuthKeysLocation) + if keysLocation == "" { + keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" + } + + return &proto.JWTConfig{ + Issuer: issuer, + Audience: config.AuthAudience, + KeysLocation: keysLocation, + } +} + +// deriveIssuerFromTokenEndpoint extracts the issuer URL from a token endpoint +func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string { + if tokenEndpoint == "" { + return "" + } + + u, err := url.Parse(tokenEndpoint) + if err != nil { + return "" + } + + return fmt.Sprintf("%s://%s/", u.Scheme, u.Host) +} diff --git a/management/internals/shared/grpc/conversion_test.go b/management/internals/shared/grpc/conversion_test.go new file mode 100644 index 000000000..701271345 --- /dev/null +++ b/management/internals/shared/grpc/conversion_test.go @@ -0,0 +1,150 @@ +package grpc + +import ( + "fmt" + "net/netip" + "reflect" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" +) + +func TestToProtocolDNSConfigWithCache(t *testing.T) { + var cache cache.DNSConfigCache + + // Create two different configs + config1 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.com", + Records: []nbdns.SimpleRecord{ + {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group1", + Name: "Group 1", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, + }, + }, + }, + } + + config2 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.org", + Records: []nbdns.SimpleRecord{ + {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group2", + Name: "Group 2", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, + }, + }, + }, + } + + // First run with config1 + result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Second run with config2 + result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort)) + + // Third run with config1 again + result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Verify that result1 and result3 are identical + if !reflect.DeepEqual(result1, result3) { + t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) + } + + // Verify that result2 is different from result1 and result3 + if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { + t.Errorf("Results should be different for different inputs") + } + + if _, exists := cache.GetNameServerGroup("group1"); !exists { + t.Errorf("Cache should contain name server group 'group1'") + } + + if _, exists := cache.GetNameServerGroup("group2"); !exists { + t.Errorf("Cache should contain name server group 'group2'") + } +} + +func BenchmarkToProtocolDNSConfig(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + testData := generateTestData(size) + + b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { + cache := &cache.DNSConfigCache{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + + b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache := &cache.DNSConfigCache{} + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + } +} + +func generateTestData(size int) nbdns.Config { + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: make([]nbdns.CustomZone, size), + NameServerGroups: make([]*nbdns.NameServerGroup, size), + } + + for i := 0; i < size; i++ { + config.CustomZones[i] = nbdns.CustomZone{ + Domain: fmt.Sprintf("domain%d.com", i), + Records: []nbdns.SimpleRecord{ + { + Name: fmt.Sprintf("record%d", i), + Type: 1, + Class: "IN", + TTL: 3600, + RData: "192.168.1.1", + }, + }, + } + + config.NameServerGroups[i] = &nbdns.NameServerGroup{ + ID: fmt.Sprintf("group%d", i), + Primary: i == 0, + Domains: []string{fmt.Sprintf("domain%d.com", i)}, + SearchDomainsEnabled: true, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + Port: 53, + NSType: 1, + }, + }, + } + } + + return config +} diff --git a/management/server/loginfilter.go b/management/internals/shared/grpc/loginfilter.go similarity index 99% rename from management/server/loginfilter.go rename to management/internals/shared/grpc/loginfilter.go index 8604af6e2..59f69dd90 100644 --- a/management/server/loginfilter.go +++ b/management/internals/shared/grpc/loginfilter.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/loginfilter_test.go b/management/internals/shared/grpc/loginfilter_test.go similarity index 99% rename from management/server/loginfilter_test.go rename to management/internals/shared/grpc/loginfilter_test.go index 65782dd9d..8b26e14ab 100644 --- a/management/server/loginfilter_test.go +++ b/management/internals/shared/grpc/loginfilter_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/grpcserver.go b/management/internals/shared/grpc/server.go similarity index 69% rename from management/server/grpcserver.go rename to management/internals/shared/grpc/server.go index 12b59b691..ad6b34c5f 100644 --- a/management/server/grpcserver.go +++ b/management/internals/shared/grpc/server.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -7,8 +7,10 @@ import ( "net" "net/netip" "os" + "strconv" "strings" "sync" + "sync/atomic" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -20,9 +22,8 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -44,49 +45,49 @@ import ( const ( envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" envBlockPeers = "NB_BLOCK_SAME_PEERS" + envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS" + + defaultSyncLim = 1000 ) -// GRPCServer an instance of a Management gRPC API server -type GRPCServer struct { +// Server an instance of a Management gRPC API server +type Server struct { accountManager account.Manager settingsManager settings.Manager - wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager ephemeral.Manager - peerLocks sync.Map - authManager auth.Manager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + peerLocks sync.Map + authManager auth.Manager logBlockedPeers bool blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + + loginFilter *loginFilter + + networkMapController network_map.Controller + + syncSem atomic.Int32 + syncLim int32 } // NewServer creates a new Management server func NewServer( - ctx context.Context, config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, -) (*GRPCServer, error) { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, err - } - + networkMapController network_map.Controller, +) (*Server, error) { if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams - err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(len(peersUpdateManager.peerChannels)) + err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { + return int64(networkMapController.CountStreams()) }) if err != nil { return nil, err @@ -96,24 +97,36 @@ func NewServer( logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" - return &GRPCServer{ - wgKey: key, - // peerKey -> event channel - peersUpdateManager: peersUpdateManager, + syncLim := int32(defaultSyncLim) + if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { + syncLimParsed, err := strconv.Atoi(syncLimStr) + if err != nil { + log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim) + } else { + //nolint:gosec + syncLim = int32(syncLimParsed) + } + } + + return &Server{ accountManager: accountManager, settingsManager: settingsManager, config: config, secretsManager: secretsManager, authManager: authManager, appMetrics: appMetrics, - ephemeralManager: ephemeralManager, logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + networkMapController: networkMapController, + + loginFilter: newLoginFilter(), + + syncLim: syncLim, }, nil } -func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { +func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { ip := "" p, ok := peer.FromContext(ctx) if ok { @@ -121,10 +134,6 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto } log.WithContext(ctx).Tracef("GetServerKey request from %s", ip) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start)) - }() // todo introduce something more meaningful with the key expiration/rotation if s.appMetrics != nil { @@ -135,8 +144,14 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto nanos := int32(now.Nanosecond()) expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err) + return nil, errors.New("failed to get wireguard key") + } + return &proto.ServerKeyResponse{ - Key: s.wgKey.PublicKey().String(), + Key: key.PublicKey().String(), ExpiresAt: expiresAt, }, nil } @@ -150,7 +165,12 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) -func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + if s.syncSem.Load() >= s.syncLim { + return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") + } + s.syncSem.Add(1) + reqStart := time.Now() ctx := srv.Context() @@ -158,20 +178,22 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi syncReq := &proto.SyncRequest{} peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { + s.syncSem.Add(-1) return err } realIP := getRealIP(ctx) sRealIP := realIP.String() peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) + log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { + s.syncSem.Add(-1) return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) } } @@ -183,48 +205,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) - unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) - defer func() { - if unlock != nil { - unlock() - } - }() - accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN") log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String()) if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + s.syncSem.Add(-1) return status.Errorf(codes.PermissionDenied, "peer is not registered") } + s.syncSem.Add(-1) return err } // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + start := time.Now() + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + metahash := metaHash(peerMeta, realIP.String()) + s.loginFilter.addLogin(peerKey.String(), metahash) + + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return mapError(ctx, err) } - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return err } - updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - - s.ephemeralManager.OnPeerConnected(ctx, peer) + updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) + s.cancelPeerRoutines(ctx, accountID, peer) + return err + } s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) @@ -235,13 +270,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi unlock() unlock = nil - log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) + s.syncSem.Add(-1) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -275,14 +310,20 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { + key, err := s.secretsManager.GetWGKey() + if err != nil { + s.cancelPeerRoutines(ctx, accountID, peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) if err != nil { @@ -293,7 +334,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() @@ -301,14 +342,13 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } - s.peersUpdateManager.CloseChannel(ctx, peer.ID) + s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - s.ephemeralManager.OnPeerDisconnected(ctx, peer) - log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) + log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key) } -func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -342,7 +382,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return userAuth.UserId, nil } -func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { +func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) start := time.Now() @@ -450,14 +490,19 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee } } -func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { +func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, parsed) if err != nil { return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") } @@ -469,11 +514,10 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case it is, the login is successful // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful -func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() realIP := getRealIP(ctx) sRealIP := realIP.String() - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -483,9 +527,9 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) + log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) } if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() @@ -509,6 +553,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) @@ -546,30 +592,31 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, mapError(ctx, err) } - // if the login request contains setup key then it is a registration request - if loginReq.GetSetupKey() != "" { - s.ephemeralManager.OnPeerDisconnected(ctx, peer) - } - loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) return nil, status.Errorf(codes.Internal, "failed logging in peer") } - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) return nil, status.Errorf(codes.Internal, "failed logging in peer") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } -func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { +func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { var relayToken *Token var err error if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { @@ -588,7 +635,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH), Checks: toProtocolChecks(ctx, postureChecks), } @@ -600,7 +647,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // // The user ID can be empty if the token is not provided, which is acceptable if the peer is already // registered or if it uses a setup key to register. -func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { +func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { userID := "" if loginReq.GetJwtToken() != "" { var err error @@ -620,166 +667,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR return userID, nil } -func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { - switch configProto { - case nbconfig.UDP: - return proto.HostConfig_UDP - case nbconfig.DTLS: - return proto.HostConfig_DTLS - case nbconfig.HTTP: - return proto.HostConfig_HTTP - case nbconfig.HTTPS: - return proto.HostConfig_HTTPS - case nbconfig.TCP: - return proto.HostConfig_TCP - default: - panic(fmt.Errorf("unexpected config protocol type %v", configProto)) - } -} - -func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { - if config == nil { - return nil - } - - var stuns []*proto.HostConfig - for _, stun := range config.Stuns { - stuns = append(stuns, &proto.HostConfig{ - Uri: stun.URI, - Protocol: ToResponseProto(stun.Proto), - }) - } - - var turns []*proto.ProtectedHostConfig - if config.TURNConfig != nil { - for _, turn := range config.TURNConfig.Turns { - var username string - var password string - if turnCredentials != nil { - username = turnCredentials.Payload - password = turnCredentials.Signature - } else { - username = turn.Username - password = turn.Password - } - turns = append(turns, &proto.ProtectedHostConfig{ - HostConfig: &proto.HostConfig{ - Uri: turn.URI, - Protocol: ToResponseProto(turn.Proto), - }, - User: username, - Password: password, - }) - } - } - - var relayCfg *proto.RelayConfig - if config.Relay != nil && len(config.Relay.Addresses) > 0 { - relayCfg = &proto.RelayConfig{ - Urls: config.Relay.Addresses, - } - - if relayToken != nil { - relayCfg.TokenPayload = relayToken.Payload - relayCfg.TokenSignature = relayToken.Signature - } - } - - var signalCfg *proto.HostConfig - if config.Signal != nil { - signalCfg = &proto.HostConfig{ - Uri: config.Signal.URI, - Protocol: ToResponseProto(config.Signal.Proto), - } - } - - nbConfig := &proto.NetbirdConfig{ - Stuns: stuns, - Turns: turns, - Signal: signalCfg, - Relay: relayCfg, - } - - return nbConfig -} - -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { - netmask, _ := network.Net.Mask.Size() - fqdn := peer.FQDN(dnsName) - return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, - RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, - LazyConnectionEnabled: settings.LazyConnectionEnabled, - } -} - -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { - response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), - NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), - }, - Checks: toProtocolChecks(ctx, checks), - } - - nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) - extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) - response.NetbirdConfig = extendedConfig - - response.NetworkMap.PeerConfig = response.PeerConfig - - remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) - response.RemotePeers = remotePeers - response.NetworkMap.RemotePeers = remotePeers - response.RemotePeersIsEmpty = len(remotePeers) == 0 - response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) - - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) - response.NetworkMap.FirewallRules = firewallRules - response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 - - routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) - response.NetworkMap.RoutesFirewallRules = routesFirewallRules - response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 - - if networkMap.ForwardingRules != nil { - forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) - for _, rule := range networkMap.ForwardingRules { - forwardingRules = append(forwardingRules, rule.ToProto()) - } - response.NetworkMap.ForwardingRules = forwardingRules - } - - return response -} - -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { - for _, rPeer := range peers { - dst = append(dst, &proto.RemotePeerConfig{ - WgPubKey: rPeer.Key, - AllowedIps: []string{rPeer.IP.String() + "/32"}, - SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, - Fqdn: rPeer.FQDN(dnsName), - AgentVersion: rPeer.Meta.WtVersion, - }) - } - return dst -} - // IsHealthy indicates whether the service is healthy -func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { +func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error { var err error var turnToken *Token @@ -803,27 +697,25 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID) + peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID) if err != nil { return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - // Get all peers in the account for forwarder port computation - allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") + plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + + key, err := s.secretsManager.GetWGKey() if err != nil { - return fmt.Errorf("get account peers: %w", err) + return status.Errorf(codes.Internal, "failed getting server key") } - dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) - - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp) if err != nil { return status.Errorf(codes.Internal, "error handling request") } err = srv.Send(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) @@ -838,12 +730,8 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p // GetDeviceAuthorizationFlow returns a device authorization flow information // This is used for initiating an Oauth 2 device authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start)) - }() peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { @@ -852,7 +740,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -882,13 +775,13 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. }, } - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -896,12 +789,8 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // GetPKCEAuthorizationFlow returns a pkce authorization flow information // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start)) - }() peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { @@ -910,7 +799,12 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -938,20 +832,20 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } // SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected, // peer's under the same account of any updates. -func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) @@ -976,7 +870,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) return &proto.Empty{}, nil } -func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) start := time.Now() diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go new file mode 100644 index 000000000..d3a12e986 --- /dev/null +++ b/management/internals/shared/grpc/server_test.go @@ -0,0 +1,108 @@ +package grpc + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { + testingServerKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testingClientKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testCases := []struct { + name string + inputFlow *config.DeviceAuthorizationFlow + expectedFlow *mgmtProto.DeviceAuthorizationFlow + expectedErrFunc require.ErrorAssertionFunc + expectedErrMSG string + expectedComparisonFunc require.ComparisonAssertionFunc + expectedComparisonMSG string + }{ + { + name: "Testing No Device Flow Config", + inputFlow: nil, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Invalid Device Flow Provider Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "NoNe", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Full Device Flow Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "hosted", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ + Provider: 0, + ProviderConfig: &mgmtProto.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.NoError, + expectedErrMSG: "should not return error", + expectedComparisonFunc: require.Equal, + expectedComparisonMSG: "should match", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mgmtServer := &Server{ + secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey}, + config: &config.Config{ + DeviceAuthorizationFlow: testCase.inputFlow, + }, + } + + message := &mgmtProto.DeviceAuthorizationFlowRequest{} + key, err := mgmtServer.secretsManager.GetWGKey() + require.NoError(t, err, "should be able to get server key") + + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message) + require.NoError(t, err, "should be able to encrypt message") + + resp, err := mgmtServer.GetDeviceAuthorizationFlow( + context.TODO(), + &mgmtProto.EncryptedMessage{ + WgPubKey: testingClientKey.PublicKey().String(), + Body: encryptedMSG, + }, + ) + testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) + if testCase.expectedComparisonFunc != nil { + flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} + + err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + require.NoError(t, err, "should be able to decrypt") + + testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) + testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) + } + }) + } +} diff --git a/management/server/token_mgr.go b/management/internals/shared/grpc/token_mgr.go similarity index 87% rename from management/server/token_mgr.go rename to management/internals/shared/grpc/token_mgr.go index f9293e7a8..ccb32202f 100644 --- a/management/server/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -10,8 +10,10 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -28,6 +30,7 @@ type SecretsManager interface { GenerateRelayToken() (*Token, error) SetupRefresh(ctx context.Context, accountID, peerKey string) CancelRefresh(peerKey string) + GetWGKey() (wgtypes.Key, error) } // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server @@ -37,16 +40,22 @@ type TimeBasedAuthSecretsManager struct { relayCfg *nbconfig.Relay turnHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator - updateManager *PeersUpdateManager + updateManager network_map.PeersUpdateManager settingsManager settings.Manager groupsManager groups.Manager turnCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{} + wgKey wgtypes.Key } type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -55,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg * relayCancelMap: make(map[string]chan struct{}), settingsManager: settingsManager, groupsManager: groupsManager, + wgKey: key, } if turnCfg != nil { @@ -80,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg * } } - return mgr + return mgr, nil +} + +// GetWGKey returns WireGuard private key used to generate peer keys +func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) { + return m.wgKey, nil } // GenerateTurnToken generates new time-based secret credentials for TURN @@ -152,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI relayCancel := make(chan struct{}, 1) m.relayCancelMap[peerID] = relayCancel go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel) - log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID) + log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID) } } @@ -163,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc for { select { case <-cancel: - log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) + log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID) return case <-ticker.C: m.pushNewTURNAndRelayTokens(ctx, accountID, peerID) @@ -178,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac for { select { case <-cancel: - log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID) + log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID) return case <-ticker.C: m.pushNewRelayTokens(ctx, accountID, peerID) @@ -227,7 +242,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -251,7 +266,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/server/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go similarity index 90% rename from management/server/token_mgr_test.go rename to management/internals/shared/grpc/token_mgr_test.go index 5c956dc31..98eb66fb5 100644 --- a/management/server/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -13,6 +13,8 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) rc := &config.Relay{ Addresses: []string{"localhost:0"}, @@ -44,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) turnCredentials, err := tested.GenerateTurnToken() require.NoError(t, err) @@ -80,7 +83,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { ttl := util.Duration{Duration: 2 * time.Second} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" updateChannel := peersManager.CreateChannel(context.Background(), peer) @@ -96,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -116,7 +120,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { t.Errorf("expecting peer to be present in the relay cancel map, got not present") } - var updates []*UpdateMessage + var updates []*network_map.UpdateMessage loop: for timeout := time.After(5 * time.Second); ; { @@ -185,7 +189,7 @@ loop: func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" rc := &config.Relay{ @@ -199,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) tested.SetupRefresh(context.Background(), "someAccountID", peer) if _, ok := tested.turnCancelMap[peer]; !ok { diff --git a/management/main.go b/management/main.go index 561ed8f26..ff8482f97 100644 --- a/management/main.go +++ b/management/main.go @@ -1,11 +1,19 @@ package main import ( - "github.com/netbirdio/netbird/management/cmd" + "log" + "net/http" + // nolint:gosec + _ "net/http/pprof" "os" + + "github.com/netbirdio/netbird/management/cmd" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() if err := cmd.Execute(); err != nil { os.Exit(1) } diff --git a/management/server/account.go b/management/server/account.go index dca105ddf..405a3c0f6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,12 +11,12 @@ import ( "reflect" "regexp" "slices" - "strconv" "strings" "sync" - "sync/atomic" "time" + "github.com/netbirdio/netbird/shared/auth" + cacheStore "github.com/eko/gocache/lib/v4/store" "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" @@ -26,6 +26,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -35,7 +37,6 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -68,28 +69,29 @@ type DefaultAccountManager struct { cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded cacheLoading map[string]chan struct{} - peersUpdateManager *PeersUpdateManager + networkMapController network_map.Controller idpManager idp.Manager cacheManager *nbcache.AccountUserDataCache externalCacheManager nbcache.UserDataCache ctx context.Context eventStore activity.Store geo geolocation.Geolocation - ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer proxyController port_forwarding.Controller settingsManager settings.Manager + // config contains the management server configuration + config *nbconfig.Config + // singleAccountMode indicates whether the instance has a single account. // If true, then every new user will end up under the same account. // This value will be set to false if management service has more than one account. singleAccountMode bool // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string - // dnsDomain is used for peer resolution. This is appended to the peer's name - dnsDomain string + peerLoginExpiry Scheduler peerInactivityExpiry Scheduler @@ -103,14 +105,11 @@ type DefaultAccountManager struct { permissionsManager permissions.Manager - accountUpdateLocks sync.Map - updateAccountPeersBufferInterval atomic.Int64 - - loginFilter *loginFilter - disableDefaultPolicy bool } +var _ account.Manager = (*DefaultAccountManager)(nil) + func isUniqueConstraintError(err error) bool { switch { case strings.Contains(err.Error(), "(SQLSTATE 23505)"), @@ -176,11 +175,11 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, + config *nbconfig.Config, store store.Store, - peersUpdateManager *PeersUpdateManager, + networkMapController network_map.Controller, idpManager idp.Manager, singleAccountModeDomain string, - dnsDomain string, eventStore activity.Store, geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, @@ -198,13 +197,13 @@ func BuildManager( am := &DefaultAccountManager{ Store: store, + config: config, geo: geo, - peersUpdateManager: peersUpdateManager, + networkMapController: networkMapController, idpManager: idpManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(), @@ -215,11 +214,10 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, - loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, } - am.startWarmup(ctx) + am.networkMapController.StartWarmup(ctx) accountsCounter, err := store.GetAccountsCounter(ctx) if err != nil { @@ -238,7 +236,7 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) if err != nil { return nil, fmt.Errorf("getting cache store: %s", err) } @@ -263,36 +261,6 @@ func BuildManager( return am, nil } -func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { - am.ephemeralManager = em -} - -func (am *DefaultAccountManager) startWarmup(ctx context.Context) { - var initialInterval int64 - intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") - interval, err := strconv.Atoi(intervalStr) - if err != nil { - initialInterval = 1 - log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) - } else { - initialInterval = int64(interval) * 10 - go func() { - startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") - startupPeriod, err := strconv.Atoi(startupPeriodStr) - if err != nil { - startupPeriod = 1 - log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) - } - time.Sleep(time.Duration(startupPeriod) * time.Second) - am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) - }() - } - am.updateAccountPeersBufferInterval.Store(initialInterval) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) - -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -327,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } - if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { + if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil { return err } + if oldSettings.Extra != nil && newSettings.Extra != nil && + oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled { + approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to approve pending peers: %w", err) + } + + if approvedCount > 0 { + log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID) + updateAccountPeers = true + } + } + if oldSettings.NetworkRange != newSettings.NetworkRange { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { return err @@ -340,7 +321,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || - oldSettings.DNSDomain != newSettings.DNSDomain { + oldSettings.DNSDomain != newSettings.DNSDomain || + oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { updateAccountPeers = true } @@ -351,6 +333,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } + newSettings.Extra.IntegratedValidatorGroups = oldSettings.Extra.IntegratedValidatorGroups + newSettings.Extra.IntegratedValidator = oldSettings.Extra.IntegratedValidator + if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil { return err } @@ -376,6 +361,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err } @@ -401,7 +387,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return newSettings, nil } -func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -415,17 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") - if err != nil { - return err - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID) + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { @@ -477,6 +453,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con } } +func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{ + "version": newSettings.AutoUpdateVersion, + }) + } +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { @@ -816,6 +800,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) + if ctx == nil { + ctx = context.Background() + } + + // nolint:staticcheck + ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString) if err != nil { return nil, nil, err @@ -1040,7 +1031,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes -func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth, +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth, primaryDomain bool, ) error { if userAuth.Domain == "" { @@ -1089,7 +1080,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, - userAuth nbcontext.UserAuth, + userAuth auth.UserAuth, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain) @@ -1108,7 +1099,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { if userAuth.UserId == "" { return "", fmt.Errorf("user ID is empty") } @@ -1139,7 +1130,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai return newAccount.Id, nil } -func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID @@ -1251,7 +1242,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { - log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + log.Errorf("failed to get account onboarding for account %s: %v", accountID, err) return nil, err } @@ -1303,7 +1294,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac return newOnboarding, nil } -func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) } @@ -1347,7 +1338,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager -func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { if userAuth.IsChild || userAuth.IsPAT { return nil } @@ -1465,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } - if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) - if err != nil { - return err - } + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) + if err != nil { + return err + } - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) - if err != nil { - return err - } + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) + if err != nil { + return err + } - if removedGroupAffectsPeers || newGroupsAffectsPeers { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) - am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) - } + if removedGroupAffectsPeers || newGroupsAffectsPeers { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) + am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } return nil @@ -1505,7 +1494,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // // UserAuth IsChild -> checks that account exists -func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory) @@ -1584,7 +1573,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont return domainAccountID, cancel, nil } -func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) { userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) @@ -1632,23 +1621,14 @@ func handleNotFound(err error) error { return nil } -func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool { +func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool { return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { - return am.loginFilter.allowLogin(wgPubKey, metahash) -} - -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start)) - }() - - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { - return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) + return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) @@ -1656,10 +1636,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - metahash := metaHash(meta, realIP.String()) - am.loginFilter.addLogin(peerPubKey, metahash) - - return peer, netMap, postureChecks, nil + return peer, netMap, postureChecks, dnsfwdPort, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { @@ -1676,41 +1653,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) + _, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { - return mapError(ctx, err) + return err } return nil } -// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() -func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { - return am.peersUpdateManager.GetAllConnectedPeers(), nil -} - -// HasConnectedChannel returns true if peers has channel in update manager, otherwise false -func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { - return am.peersUpdateManager.HasChannel(peerID) -} - var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { return invalidDomainRegexp.MatchString(domain) } -// GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { - if settings == nil { - return am.dnsDomain - } - if settings.DNSDomain == "" { - return am.dnsDomain - } - - return settings.DNSDomain -} - func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { peers := []*nbpeer.Peer{} log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID) @@ -2129,7 +2084,14 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us } if updateNetworkMap { - am.BufferUpdateAccountPeers(ctx, accountID) + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil } @@ -2177,7 +2139,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti if err != nil { return fmt.Errorf("get account settings: %w", err) } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) eventMeta := peer.EventMeta(dnsDomain) oldIP := peer.IP.String() diff --git a/management/server/account/manager.go b/management/server/account/manager.go index a1ed9498b..b5921ec7a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -6,13 +6,13 @@ import ( "net/netip" "time" + "github.com/netbirdio/netbird/shared/auth" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -45,10 +45,10 @@ type Manager interface { GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) - GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) - GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error @@ -89,7 +89,6 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) @@ -97,10 +96,8 @@ type Manager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) - LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - GetAllConnectedPeers() (map[string]struct{}, error) - HasConnectedChannel(peerID string) bool + LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) @@ -109,8 +106,8 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -120,12 +117,10 @@ type Manager interface { UpdateAccountPeers(ctx context.Context, accountID string) BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error + SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) - GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) - SetEphemeralManager(em ephemeral.Manager) - AllowSync(string, uint64) bool + GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) } diff --git a/management/server/account/request_buffer.go b/management/server/account/request_buffer.go new file mode 100644 index 000000000..eced1929f --- /dev/null +++ b/management/server/account/request_buffer.go @@ -0,0 +1,11 @@ +package account + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/types" +) + +type RequestBuffer interface { + GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 07d2f2383..25818ada2 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -22,10 +22,15 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -42,6 +47,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" ) func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) { @@ -391,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -406,7 +412,7 @@ func TestNewAccount(t *testing.T) { } func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -442,7 +448,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { - type initUserParams nbcontext.UserAuth + type initUserParams auth.UserAuth var ( publicDomain = "public.com" @@ -465,7 +471,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCases := []struct { name string - inputClaims nbcontext.UserAuth + inputClaims auth.UserAuth inputInitUserParams initUserParams inputUpdateAttrs bool inputUpdateClaimAccount bool @@ -480,7 +486,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }{ { name: "New User With Public Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: publicDomain, UserId: "pub-domain-user", DomainCategory: types.PublicCategory, @@ -497,7 +503,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Unknown Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: unknownDomain, UserId: "unknown-domain-user", DomainCategory: types.UnknownCategory, @@ -514,7 +520,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -531,7 +537,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New Regular User With Existing Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "new-pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -549,7 +555,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing User With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -566,7 +572,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -584,7 +590,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "User With Private Category And Empty Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: "", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -603,7 +609,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) @@ -613,7 +619,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -644,7 +650,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" _ = newAccountWithId(context.Background(), "", userId, domain, false) - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") @@ -653,7 +659,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount Domain: domain, UserId: userId, @@ -705,7 +711,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { } func TestAccountManager_PrivateAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -731,7 +737,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } func TestAccountManager_SetOrUpdateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -768,7 +774,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } func TestAccountManager_GetAccountByUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +811,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string) } func TestAccountManager_GetAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -843,7 +849,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } func TestAccountManager_DeleteAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -912,19 +918,19 @@ func TestAccountManager_DeleteAccount(t *testing.T) { } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ Domain: "example.com", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, } - publicClaims := nbcontext.UserAuth{ + publicClaims := auth.UserAuth{ Domain: "test.com", UserId: "public-domain-user", DomainCategory: types.PublicCategory, } - am, err := createManager(b) + am, _, err := createManager(b) if err != nil { b.Fatal(err) return @@ -1016,7 +1022,7 @@ func genUsers(p string, n int) map[string]*types.User { } func TestAccountManager_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1086,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) { } func TestAccountManager_AddPeerWithUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1154,8 +1160,17 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } +func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SaveGroup(t) +} + func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + testAccountManager_NetworkUpdates_SaveGroup(t) +} + +func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1181,8 +1196,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1205,11 +1220,20 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { - manager, account, peer1, _, _ := setupNetworkMapTest(t) +func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePolicy(t) +} - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) +func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + +func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) // Ensure that we do not receive an update message before the policy is deleted time.Sleep(time.Second) @@ -1239,8 +1263,17 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SavePolicy(t) +} + func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { - manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + testAccountManager_NetworkUpdates_SavePolicy(t) +} + +func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ AccountID: account.Id, @@ -1253,8 +1286,8 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1288,8 +1321,17 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePeer(t) +} + func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { - manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + testAccountManager_NetworkUpdates_DeletePeer(t) +} + +func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1318,8 +1360,11 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + // We need to sleep to wait for the buffer peer update + time.Sleep(300 * time.Millisecond) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1341,11 +1386,20 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) +func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeleteGroup(t) +} - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) +func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + +func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1377,6 +1431,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } + for drained := false; !drained; { + select { + case <-updMsg: + default: + drained = true + } + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1404,7 +1466,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func TestAccountManager_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1485,7 +1547,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy } func TestGetUsersFromAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1736,7 +1798,9 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, + NetworkMapCache: &types.NetworkMapBuilder{}, } + account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) @@ -1782,7 +1846,7 @@ func hasNilField(x interface{}) error { } func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1797,7 +1861,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1853,7 +1917,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { } func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1896,7 +1960,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1958,7 +2022,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test } func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1994,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } +func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + accountID := account.Id + userID := account.Users[account.CreatedBy].Id + ctx := context.Background() + + newSettings := account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: true, + } + _, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + peer1.Status.RequiresApproval = true + peer2.Status.RequiresApproval = true + peer3.Status.RequiresApproval = false + + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3)) + + newSettings = account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: false, + } + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range accountPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID) + } +} + func TestAccount_GetExpiredPeers(t *testing.T) { type test struct { name string @@ -2622,7 +2723,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", "postgres") - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") // create a new account @@ -2648,7 +2749,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") t.Run("skip sync for token auth type", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group3"}, @@ -2663,7 +2764,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("empty jwt groups", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2677,7 +2778,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("jwt match existing api group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2698,7 +2799,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"])) - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2716,7 +2817,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add jwt group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1", "group2"}, @@ -2730,7 +2831,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group2"}, @@ -2744,7 +2845,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add new group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{"group1", "group3"}, @@ -2762,7 +2863,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2777,7 +2878,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{}, @@ -2864,18 +2965,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { // Fatalf(format string, args ...interface{}) // } -func createManager(t testing.TB) (*DefaultAccountManager, error) { +func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) if err != nil { - return nil, err + return nil, nil, err } ctrl := gomock.NewController(t) @@ -2893,12 +2994,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { - return nil, err + return nil, nil, err } - return manager, nil + return manager, updateManager, nil } func createStore(t testing.TB) (store.Store, error) { @@ -2927,10 +3033,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() - manager, err := createManager(t) + manager, updateManager, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2971,10 +3077,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, peer2 := getPeer(manager, setupKey) peer3 := getPeer(manager, setupKey) - return manager, account, peer1, peer2, peer3 + return manager, updateManager, account, peer1, peer2, peer3 } -func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -2984,7 +3090,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag } } -func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3022,7 +3128,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3031,16 +3137,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) assert.NoError(b, err) } @@ -3085,7 +3189,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3094,11 +3198,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3155,7 +3258,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3164,11 +3267,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3227,7 +3329,7 @@ func TestMain(m *testing.M) { } func Test_GetCreateAccountByPrivateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3273,7 +3375,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) { } func Test_UpdateToPrimaryAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3303,12 +3405,12 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { } func TestDefaultAccountManager_IsCacheCold(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) t.Run("memory cache", func(t *testing.T) { t.Run("should always return true", func(t *testing.T) { - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) cold, err := manager.isCacheCold(context.Background(), cacheStore) @@ -3323,7 +3425,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) t.Run("should return true when no account exists", func(t *testing.T) { @@ -3353,7 +3455,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { } func TestPropagateUserGroupMemberships(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) ctx := context.Background() @@ -3470,7 +3572,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3502,7 +3604,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3541,7 +3643,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -3608,7 +3710,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { } func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3630,7 +3732,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { // Test adding new user to existing account with approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3654,13 +3756,13 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { } func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } // Create a domain-based account without user approval - ownerUserAuth := nbcontext.UserAuth{ + ownerUserAuth := auth.UserAuth{ UserId: "owner-user", Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3679,7 +3781,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { // Test adding new user to existing account without approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 5c5989f84..6344b2904 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -179,6 +179,9 @@ const ( PeerIPUpdated Activity = 88 UserApproved Activity = 89 UserRejected Activity = 90 + UserCreated Activity = 91 + + AccountAutoUpdateVersionUpdated Activity = 92 AccountDeleted Activity = 99999 ) @@ -286,8 +289,12 @@ var activityMap = map[Activity]Code{ AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, - UserApproved: {"User approved", "user.approve"}, - UserRejected: {"User rejected", "user.reject"}, + + UserApproved: {"User approved", "user.approve"}, + UserRejected: {"User rejected", "user.reject"}, + UserCreated: {"User created", "user.create"}, + + AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index 80b165938..ffecb6b8f 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -7,6 +7,7 @@ import ( "path/filepath" "runtime" "strconv" + "time" log "github.com/sirupsen/logrus" "gorm.io/driver/postgres" @@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e return nil, err } - if storeEngine == types.SqliteStoreEngine { - sqlDB.SetMaxOpenConns(1) - } else { - conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) - if err != nil { - conns = runtime.NumCPU() - } - sqlDB.SetMaxOpenConns(conns) + conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) + if err != nil { + conns = runtime.NumCPU() } + if storeEngine == types.SqliteStoreEngine { + conns = 1 + } + + sqlDB.SetMaxOpenConns(conns) + sqlDB.SetMaxIdleConns(conns) + sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(3 * time.Minute) + + log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) return db, nil } diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index ece9dc321..0c62357dc 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -9,18 +9,19 @@ import ( "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/base62" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) var _ Manager = (*manager)(nil) type Manager interface { - ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsed(ctx context.Context, tokenID string) error GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } @@ -55,20 +56,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s } } -func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { token, err := m.validator.ValidateAndParse(ctx, value) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } userAuth, err := m.extractor.ToUserAuth(token) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } return userAuth, token, err } -func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go index 30a7a7161..edf158a49 100644 --- a/management/server/auth/manager_mock.go +++ b/management/server/auth/manager_mock.go @@ -3,9 +3,10 @@ package auth import ( "context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/golang-jwt/jwt/v5" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" ) @@ -15,18 +16,18 @@ var ( // @note really dislike this mocking approach but rather than have to do additional test refactoring. type MockManager struct { - ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsedFunc func(ctx context.Context, tokenID string) error GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } // EnsureUserAccessByJWTGroups implements Manager. -func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if m.EnsureUserAccessByJWTGroupsFunc != nil { return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token) } - return nbcontext.UserAuth{}, nil + return auth.UserAuth{}, nil } // GetPATInfo implements Manager. @@ -46,9 +47,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error { } // ValidateAndParseToken implements Manager. -func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { if m.ValidateAndParseTokenFunc != nil { return m.ValidateAndParseTokenFunc(ctx, value) } - return nbcontext.UserAuth{}, &jwt.Token{}, nil + return auth.UserAuth{}, &jwt.Token{}, nil } diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index c8015eb37..b9f091b1e 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -17,10 +17,10 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { @@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { } // this has been validated and parsed by ValidateAndParseToken - userAuth := nbcontext.UserAuth{ + userAuth := nbauth.UserAuth{ AccountId: account.Id, Domain: domain, UserId: userId, @@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tests := []struct { name string tokenFunc func() string - expected *nbcontext.UserAuth // nil indicates expected error + expected *nbauth.UserAuth // nil indicates expected error }{ { name: "Valid with custom claims", @@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", AccountId: "account-id|567", Domain: "http://localhost", @@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", }, }, diff --git a/management/server/cache/idp.go b/management/server/cache/idp.go index 1b31ff82a..19dfc0f38 100644 --- a/management/server/cache/idp.go +++ b/management/server/cache/idp.go @@ -18,6 +18,7 @@ const ( DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days DefaultIDPCacheCleanupInterval = 30 * time.Minute + DefaultIDPCacheOpenConn = 100 ) // UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects. diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index 3fcfbb11a..0e8061e94 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } - cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) + cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn) if err != nil { t.Fatalf("couldn't create cache store: %s", err) } diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 1c141a180..54b0242de 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -3,6 +3,7 @@ package cache import ( "context" "fmt" + "math" "os" "time" @@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. -func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) { +func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { redisAddr := os.Getenv(RedisStoreEnvVar) if redisAddr != "" { - return getRedisStore(ctx, redisAddr) + return getRedisStore(ctx, redisAddr, maxConn) } goc := gocache.New(maxTimeout, cleanupInterval) return gocache_store.NewGoCache(goc), nil } -func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) { +func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { return nil, fmt.Errorf("parsing redis cache url: %s", err) } - options.MaxIdleConns = 6 - options.MinIdleConns = 3 - options.MaxActiveConns = 100 + options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns + options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns + options.MaxActiveConns = maxConn + options.ConnMaxIdleTime = 30 * time.Minute + options.ConnMaxLifetime = 0 + options.PoolTimeout = 10 * time.Second redisClient := redis.NewClient(options) subCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() diff --git a/management/server/cache/store_test.go b/management/server/cache/store_test.go index f49dd6bbd..1b64fd70d 100644 --- a/management/server/cache/store_test.go +++ b/management/server/cache/store_test.go @@ -15,7 +15,7 @@ import ( ) func TestMemoryStore(t *testing.T) { - memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create memory store: %s", err) } @@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) { func TestRedisStoreConnectionFailure(t *testing.T) { t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379") - _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond) + _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100) if err == nil { t.Fatal("getting redis cache store should return error") } @@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) { } t.Setenv(cache.RedisStoreEnvVar, redisURL) - redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create redis store: %s", err) } diff --git a/management/server/context/auth.go b/management/server/context/auth.go index 5cb28ddb7..cc59b8a63 100644 --- a/management/server/context/auth.go +++ b/management/server/context/auth.go @@ -4,7 +4,8 @@ import ( "context" "fmt" "net/http" - "time" + + "github.com/netbirdio/netbird/shared/auth" ) type key int @@ -13,45 +14,22 @@ const ( UserAuthContextKey key = iota ) -type UserAuth struct { - // The account id the user is accessing - AccountId string - // The account domain - Domain string - // The account domain category, TBC values - DomainCategory string - // Indicates whether this user was invited, TBC logic - Invited bool - // Indicates whether this is a child account - IsChild bool - - // The user id - UserId string - // Last login time for this user - LastLogin time.Time - // The Groups the user belongs to on this account - Groups []string - - // Indicates whether this user has authenticated with a Personal Access Token - IsPAT bool -} - -func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) { +func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) { return GetUserAuthFromContext(r.Context()) } -func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request { +func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request { return r.WithContext(SetUserAuthInContext(r.Context(), userAuth)) } -func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) { - if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok { +func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) { + if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok { return userAuth, nil } - return UserAuth{}, fmt.Errorf("user auth not in context") + return auth.UserAuth{}, fmt.Errorf("user auth not in context") } -func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context { +func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context { //nolint ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId) //nolint diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..baf6debc3 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,54 +3,23 @@ package server import ( "context" "slices" - "sync" log "github.com/sirupsen/logrus" - "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort ) -const dnsForwarderPortMinVersion = "v0.59.0" - -// DNSConfigCache is a thread-safe cache for DNS configuration components -type DNSConfigCache struct { - NameServerGroups sync.Map -} - -// GetNameServerGroup retrieves a cached name server group -func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { - if c == nil { - return nil, false - } - if value, ok := c.NameServerGroups.Load(key); ok { - return value.(*proto.NameServerGroup), true - } - return nil, false -} - -// SetNameServerGroup stores a name server group in the cache -func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { - if c == nil { - return - } - c.NameServerGroups.Store(key, value) -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) @@ -191,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } - -// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. -// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. -func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { - if len(peers) == 0 { - return oldForwarderPort - } - - reqVer := semver.Canonical(requiredVersion) - - // Check if all peers have the required version or newer - for _, peer := range peers { - - // Development version is always supported - if peer.Meta.WtVersion == "development" { - continue - } - peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) - if peerVersion == "" { - // If any peer doesn't have version info, return 0 - return oldForwarderPort - } - - // Compare versions - if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort - } - } - - // All peers have the required version or newer - return dnsForwarderPort -} - -// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { - protoUpdate := &proto.DNSConfig{ - ServiceEnable: update.ServiceEnable, - CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), - NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), - ForwarderPort: forwardPort, - } - - for _, zone := range update.CustomZones { - protoZone := convertToProtoCustomZone(zone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } - - for _, nsGroup := range update.NameServerGroups { - cacheKey := nsGroup.ID - if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) - } else { - protoGroup := convertToProtoNameServerGroup(nsGroup) - cache.SetNameServerGroup(cacheKey, protoGroup) - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) - } - } - - return protoUpdate -} - -// Helper function to convert nbdns.CustomZone to proto.CustomZone -func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { - protoZone := &proto.CustomZone{ - Domain: zone.Domain, - Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), - } - for _, record := range zone.Records { - protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ - Name: record.Name, - Type: int64(record.Type), - Class: record.Class, - TTL: int64(record.TTL), - RData: record.RData, - }) - } - return protoZone -} - -// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup -func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { - protoGroup := &proto.NameServerGroup{ - Primary: nsGroup.Primary, - Domains: nsGroup.Domains, - SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, - NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), - } - for _, ns := range nsGroup.NameServers { - protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ - IP: ns.IP.String(), - Port: int64(ns.Port), - NSType: int64(ns.NSType), - }) - } - return protoGroup -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 83caf74ef..b5e3f2b99 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "net/netip" - "reflect" "testing" "time" @@ -12,6 +10,11 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -218,7 +221,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { // return empty extra settings for expected calls to UpdateAccountPeers settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + + return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { @@ -344,247 +353,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return am.Store.GetAccount(context.Background(), account.Id) } -func generateTestData(size int) nbdns.Config { - config := nbdns.Config{ - ServiceEnable: true, - CustomZones: make([]nbdns.CustomZone, size), - NameServerGroups: make([]*nbdns.NameServerGroup, size), - } - - for i := 0; i < size; i++ { - config.CustomZones[i] = nbdns.CustomZone{ - Domain: fmt.Sprintf("domain%d.com", i), - Records: []nbdns.SimpleRecord{ - { - Name: fmt.Sprintf("record%d", i), - Type: 1, - Class: "IN", - TTL: 3600, - RData: "192.168.1.1", - }, - }, - } - - config.NameServerGroups[i] = &nbdns.NameServerGroup{ - ID: fmt.Sprintf("group%d", i), - Primary: i == 0, - Domains: []string{fmt.Sprintf("domain%d.com", i)}, - SearchDomainsEnabled: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("8.8.8.8"), - Port: 53, - NSType: 1, - }, - }, - } - } - - return config -} - -func BenchmarkToProtocolDNSConfig(b *testing.B) { - sizes := []int{10, 100, 1000} - - for _, size := range sizes { - testData := generateTestData(size) - - b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { - cache := &DNSConfigCache{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) - } - }) - - b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) - } - }) - } -} - -func TestToProtocolDNSConfigWithCache(t *testing.T) { - var cache DNSConfigCache - - // Create two different configs - config1 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.com", - Records: []nbdns.SimpleRecord{ - {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group1", - Name: "Group 1", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, - }, - }, - }, - } - - config2 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.org", - Records: []nbdns.SimpleRecord{ - {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group2", - Name: "Group 2", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, - }, - }, - }, - } - - // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) - - // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) - - // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) - - // Verify that result1 and result3 are identical - if !reflect.DeepEqual(result1, result3) { - t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) - } - - // Verify that result2 is different from result1 and result3 - if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { - t.Errorf("Results should be different for different inputs") - } - - if _, exists := cache.GetNameServerGroup("group1"); !exists { - t.Errorf("Cache should contain name server group 'group1'") - } - - if _, exists := cache.GetNameServerGroup("group2"); !exists { - t.Errorf("Cache should contain name server group 'group2'") - } -} - -func TestComputeForwarderPort(t *testing.T) { - // Test with empty peers list - peers := []*nbpeer.Peer{} - result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { - t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) - } - - // Test with peers that have old versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.26.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { - t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have new versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { - t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) - } - - // Test with peers that have mixed versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { - t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have empty version - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { - t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) - } - - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "development", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { - t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) - } - - // Test with peers that have unknown version string - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "unknown", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { - t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) - } -} - func TestDNSAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { @@ -600,9 +370,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates diff --git a/management/server/event_test.go b/management/server/event_test.go index 8c56fd3f6..420e69866 100644 --- a/management/server/event_test.go +++ b/management/server/event_test.go @@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac } func TestDefaultAccountManager_GetEvents(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { return } diff --git a/management/server/group.go b/management/server/group.go index 487cb6d97..84e641f26 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) if err != nil { return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) @@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } } - newGroup.AccountID = accountID - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) if err != nil { return err @@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac if err == nil && oldGroup != nil { addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) + + if oldGroup.Name != newGroup.Name { + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "old_name": oldGroup.Name, + "new_name": newGroup.Name, + } + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta) + }) + } } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { @@ -354,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) return nil } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peerID := range addedPeers { peer, ok := peers[peerID] diff --git a/management/server/group_test.go b/management/server/group_test.go index 31ff29cbc..4935dac5d 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -37,7 +37,7 @@ const ( ) func TestDefaultAccountManager_CreateGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Fatalf("failed to create account manager: %s", err) } @@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroups(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) assert.NoError(t, err, "Failed to create account manager") manager, account, err := initTestGroupAccount(am) @@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t } func TestGroupAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving a group that is not linked to any resource should not update account peers @@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } func Test_AddPeerToGroup(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) { } func Test_AddPeerToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) { } func Test_AddPeerAndAddToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP { } func Test_IncrementNetworkSerial(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3d4de31d0..b7c6c113c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -4,11 +4,16 @@ import ( "context" "fmt" "net/http" + "os" + "strconv" + "time" "github.com/gorilla/mux" "github.com/rs/cors" + log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" @@ -16,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" + nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" @@ -34,11 +40,15 @@ import ( nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - nbpeers "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" +const ( + apiPrefix = "/api" + rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" + rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" + rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" +) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. func NewAPIHandler( @@ -56,13 +66,46 @@ func NewAPIHandler( permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, + networkMapController network_map.Controller, ) (http.Handler, error) { + var rateLimitingConfig *middleware.RateLimiterConfig + if os.Getenv(rateLimitingEnabledKey) == "true" { + rpm := 6 + if v := os.Getenv(rateLimitingRPMKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) + } else { + rpm = value + } + } + + burst := 500 + if v := os.Getenv(rateLimitingBurstKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) + } else { + burst = value + } + } + + rateLimitingConfig = &middleware.RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + } + } + authMiddleware := middleware.NewAuthMiddleware( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, + rateLimitingConfig, + appMetrics.GetMeter(), ) corsMiddleware := cors.AllowAll() @@ -80,7 +123,7 @@ func NewAPIHandler( } accounts.AddEndpoints(accountManager, settingsManager, router) - peers.AddEndpoints(accountManager, router) + peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) policies.AddEndpoints(accountManager, LocationManager, router) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index f1552d0ea..3797b0512 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -3,12 +3,15 @@ package accounts import ( "context" "encoding/json" + "fmt" "net/http" "net/netip" "time" "github.com/gorilla/mux" + goversion "github.com/hashicorp/go-version" + "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/settings" @@ -26,7 +29,9 @@ const ( // MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16) MinNetworkBitsIPv4 = 28 // MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges - MinNetworkBitsIPv6 = 120 + MinNetworkBitsIPv6 = 120 + disableAutoUpdate = "disabled" + autoUpdateLatestVersion = "latest" ) // handler is a handler that handles the server.Account HTTP endpoints @@ -162,6 +167,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } +func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) { + returnSettings := &types.Settings{ + PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), + RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), + } + + if req.Settings.Extra != nil { + returnSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, + FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, + FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, + FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, + } + } + + if req.Settings.JwtGroupsEnabled != nil { + returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled + } + if req.Settings.GroupsPropagationEnabled != nil { + returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled + } + if req.Settings.JwtGroupsClaimName != nil { + returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName + } + if req.Settings.JwtAllowGroups != nil { + returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups + } + if req.Settings.RoutingPeerDnsResolutionEnabled != nil { + returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled + } + if req.Settings.DnsDomain != nil { + returnSettings.DNSDomain = *req.Settings.DnsDomain + } + if req.Settings.LazyConnectionEnabled != nil { + returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled + } + if req.Settings.AutoUpdateVersion != nil { + _, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion) + if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion || + *req.Settings.AutoUpdateVersion == disableAutoUpdate || + err == nil { + returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion + } else if *req.Settings.AutoUpdateVersion != "" { + return nil, fmt.Errorf("invalid AutoUpdateVersion") + } + } + + return returnSettings, nil +} + // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) @@ -186,45 +246,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - settings := &types.Settings{ - PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, - PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), - RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), - } - - if req.Settings.Extra != nil { - settings.Extra = &types.ExtraSettings{ - PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, - UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, - FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, - FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, - FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, - } - } - - if req.Settings.JwtGroupsEnabled != nil { - settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled - } - if req.Settings.GroupsPropagationEnabled != nil { - settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled - } - if req.Settings.JwtGroupsClaimName != nil { - settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName - } - if req.Settings.JwtAllowGroups != nil { - settings.JWTAllowGroups = *req.Settings.JwtAllowGroups - } - if req.Settings.RoutingPeerDnsResolutionEnabled != nil { - settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled - } - if req.Settings.DnsDomain != nil { - settings.DNSDomain = *req.Settings.DnsDomain - } - if req.Settings.LazyConnectionEnabled != nil { - settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled + settings, err := h.updateAccountRequestSettings(req) + if err != nil { + util.WriteError(r.Context(), err, w) + return } if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" { prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange) @@ -313,6 +338,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, + AutoUpdateVersion: &settings.AutoUpdateVersion, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 4b9b79fdc..2e48ac83e 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -120,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: true, expectedID: accountID, @@ -142,6 +144,30 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + }, + expectedArray: false, + expectedID: accountID, + }, + { + name: "PutAccount OK with autoUpdateVersion", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + AutoUpdateVersion: sr("latest"), }, expectedArray: false, expectedID: accountID, @@ -164,6 +190,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: false, expectedID: accountID, @@ -186,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: false, expectedID: accountID, @@ -208,6 +236,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: false, expectedID: accountID, @@ -236,7 +265,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, AccountId: accountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 08a0b2afd..67638aea5 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index 42b519c29..a027c067e 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, AccountId: testingDNSSettingsAccount.Id, Domain: testingDNSSettingsAccount.Domain, diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index d49b6c7e0..4716782f3 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", AccountId: testNSGroupAccountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index a0695fa3f..923a24e31 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -14,11 +14,12 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) func initEventsTestData(account string, events ...*activity.Event) *handler { @@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_account", diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e861e873c..56ccc9d0b 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -11,10 +11,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns groups of the account @@ -48,6 +48,29 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } accountID, userID := userAuth.AccountId, userAuth.UserId + // Check if filtering by name + groupName := r.URL.Query().Get("name") + if groupName != "" { + // Get single group by name + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + // Return as array with single element to maintain API consistency + groupsResponse := []*api.Group{toGroupResponse(accountPeers, group)} + util.WriteJSONObject(r.Context(), w, groupsResponse) + return + } + + // Get all groups groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 34694ec8c..458a15c11 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -19,12 +19,13 @@ import ( "github.com/netbirdio/netbird/management/server" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -59,12 +60,23 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return group, nil }, + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + groups := []*types.Group{ + {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + {ID: "id-existed", Name: "Existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, + } + + groups = append(groups, initGroups...) + + return groups, nil + }, GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } - return nil, fmt.Errorf("unknown group name") + return nil, status.Errorf(status.NotFound, "unknown group name") }, GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { return maps.Values(TestPeers), nil @@ -122,7 +134,7 @@ func TestGetGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -248,7 +260,7 @@ func TestWriteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -286,6 +298,84 @@ func TestWriteGroup(t *testing.T) { } } +func TestGetAllGroups(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + expectedCount int + }{ + { + name: "Get All Groups", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups", + expectedStatus: http.StatusOK, + expectedCount: 3, // id-jwt-group, id-existed, id-all + }, + { + name: "Get Group By Name - Existing", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups?name=All", + expectedStatus: http.StatusOK, + expectedCount: 1, + }, + { + name: "Get Group By Name - Not Found", + expectedBody: false, + requestType: http.MethodGet, + requestPath: "/api/groups?name=NonExistent", + expectedStatus: http.StatusNotFound, + }, + } + + p := initGroupTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) + + router := mux.NewRouter() + router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.expectedStatus) + return + } + + if !tc.expectedBody { + return + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var groups []api.Group + if err = json.Unmarshal(content, &groups); err != nil { + t.Fatalf("Response is not in correct json format; %v", err) + } + + assert.Equal(t, tc.expectedCount, len(groups)) + }) + } +} + func TestDeleteGroup(t *testing.T) { tt := []struct { name string @@ -330,7 +420,7 @@ func TestDeleteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index d7b598a5d..f99eca794 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -12,15 +12,15 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/shared/management/status" nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) // handler is a handler that returns networks of the account diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index 59396dceb..c31729a39 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -8,10 +8,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type resourceHandler struct { diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 2e64c637f..c311a29fe 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -7,10 +7,10 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type routersHandler struct { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..a5c9ab0ac 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -23,11 +24,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { - accountManager account.Manager + accountManager account.Manager + networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - peersHandler := NewHandler(accountManager) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { + peersHandler := NewHandler(accountManager, networkMapController) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -36,25 +38,13 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { return &Handler{ - accountManager: accountManager, + accountManager: accountManager, + networkMapController: networkMapController, } } -func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { - peerToReturn := peer.Copy() - if peer.Status.Connected { - // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected - // This may happen after server restart when not all peers are yet connected - if !h.accountManager.HasConnectedChannel(peer.ID) { - peerToReturn.Status.Connected = false - } - } - - return peerToReturn, nil -} - func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { @@ -62,23 +52,18 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(ctx, err, w) - return - } settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) if err != nil { util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +71,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -137,7 +124,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -147,16 +134,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -224,38 +212,35 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) + respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,17 +289,17 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - dnsDomain := h.accountManager.GetDNSDomain(account.Settings) + dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } @@ -384,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) PortRanges: []types.RulePortRange{portRange}, }}, } + if protocol == types.PolicyRuleProtocolNetbirdSSH { + policy.Rules[0].AuthorizedUser = userAuth.UserId + } _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) if err != nil { @@ -430,13 +418,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -464,7 +452,25 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, + LocalFlags: &api.PeerLocalFlags{ + BlockInbound: &peer.Meta.Flags.BlockInbound, + BlockLanAccess: &peer.Meta.Flags.BlockLANAccess, + DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes, + DisableDns: &peer.Meta.Flags.DisableDNS, + DisableFirewall: &peer.Meta.Flags.DisableFirewall, + DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes, + LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled, + RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled, + RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive, + ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed, + }, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { @@ -472,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn if osVersion == "" { osVersion = peer.Meta.Core } - return &api.PeerBatch{ CreatedAt: peer.CreatedAt, Id: peer.ID, @@ -501,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, + LocalFlags: &api.PeerLocalFlags{ + BlockInbound: &peer.Meta.Flags.BlockInbound, + BlockLanAccess: &peer.Meta.Flags.BlockLANAccess, + DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes, + DisableDns: &peer.Meta.Flags.DisableDNS, + DisableFirewall: &peer.Meta.Flags.DisableFirewall, + DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes, + LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled, + RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled, + RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive, + ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed, + }, } } diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 94564113f..55e779ff0 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -14,12 +14,15 @@ import ( "time" "github.com/gorilla/mux" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,7 +39,7 @@ const ( serviceUser = "service_user" ) -func initTestMetaData(peers ...*nbpeer.Peer) *Handler { +func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { @@ -99,6 +102,14 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, } + ctrl := gomock.NewController(t) + + networkMapController := network_map.NewMockController(ctrl) + networkMapController.EXPECT(). + GetDNSDomain(gomock.Any()). + Return("domain"). + AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -187,6 +198,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { return account.Settings, nil }, }, + networkMapController: networkMapController, } } @@ -249,14 +261,6 @@ func TestGetPeers(t *testing.T) { expectedArray: false, expectedPeer: peer, }, - { - name: "GetPeer with no update channel", - requestType: http.MethodGet, - requestPath: "/api/peers/" + peer1.ID, - expectedStatus: http.StatusOK, - expectedArray: false, - expectedPeer: expectedPeer1, - }, { name: "PutPeer", requestType: http.MethodPut, @@ -270,14 +274,14 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer, peer1) + p := initTestMetaData(t, peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "admin_user", Domain: "hotmail.com", AccountId: "test_id", @@ -316,8 +320,6 @@ func TestGetPeers(t *testing.T) { for _, peer := range respBody { if peer.Id == testPeerID { got = peer - } else { - assert.Equal(t, peer.Connected, false) } } @@ -331,14 +333,14 @@ func TestGetPeers(t *testing.T) { t.Log(got) - assert.Equal(t, got.Name, tc.expectedPeer.Name) - assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) - assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) - assert.Equal(t, got.Os, "OS core") - assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) - assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) - assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) - assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber) + assert.Equal(t, tc.expectedPeer.Name, got.Name) + assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version) + assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip) + assert.Equal(t, "OS core", got.Os) + assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled) + assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled) + assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected) + assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber) }) } } @@ -374,7 +376,7 @@ func TestGetAccessiblePeers(t *testing.T) { UserID: regularUser, } - p := initTestMetaData(peer1, peer2, peer3) + p := initTestMetaData(t, peer1, peer2, peer3) tt := []struct { name string @@ -425,7 +427,7 @@ func TestGetAccessiblePeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", @@ -477,7 +479,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { }, } - p := initTestMetaData(testPeer) + p := initTestMetaData(t, testPeer) tt := []struct { name string @@ -508,7 +510,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody))) req.Header.Set("Content-Type", "application/json") - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index cedd5ac88..094a36e38 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -16,12 +16,13 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/util" ) @@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index cb6995793..a2d656a47 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -9,11 +9,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 4d6bad5e3..e4d1d73df 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account @@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: pr.Protocol = types.PolicyRuleProtocolICMP + case api.PolicyRuleUpdateProtocolNetbirdSsh: + pr.Protocol = types.PolicyRuleProtocolNetbirdSSH default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s } } + if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 { + for _, sourceGroupID := range pr.Sources { + _, ok := (*rule.AuthorizedGroups)[sourceGroupID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w) + return + } + } + pr.AuthorizedGroups = *rule.AuthorizedGroups + } + // validate policy object if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP { if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { @@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { DestinationResource: r.DestinationResource.ToAPIResponse(), } + if len(r.AuthorizedGroups) != 0 { + authorizedGroupsCopy := r.AuthorizedGroups + rule.AuthorizedGroups = &authorizedGroupsCopy + } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index fd39ae2a3..ca5a0a6ab 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -14,10 +14,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initPoliciesTestData(policies ...*types.Policy) *handler { @@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 3ebc4d1e1..744cde10b 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index c644b533a..35198da32 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -16,9 +16,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -45,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint } return postureChecks, nil @@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 466a7987f..a44d81e3e 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" @@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: testAccountID, diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 2287dadfe..d267b6eea 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 7b46b486b..b137b6dd1 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -15,10 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, Domain: "hotmail.com", AccountId: "testAccountId", diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index bae07af4a..867db3ca9 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -8,10 +8,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 92544c56d..7cda14468 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -17,10 +17,11 @@ import ( "github.com/netbirdio/netbird/management/server/util" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index e08004218..37f0a6c1d 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -128,7 +129,7 @@ func initUsersTestData() *handler { return nil }, - GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { switch userAuth.UserId { case "not-found": return nil, status.NewUserNotFoundError("not-found") @@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) { tt := []struct { name string expectedStatus int - requestAuth nbcontext.UserAuth + requestAuth auth.UserAuth expectedResult *api.User }{ { @@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) { }, { name: "user not found", - requestAuth: nbcontext.UserAuth{UserId: "not-found"}, + requestAuth: auth.UserAuth{UserId: "not-found"}, expectedStatus: http.StatusNotFound, }, { name: "not of account", - requestAuth: nbcontext.UserAuth{UserId: "not-of-account"}, + requestAuth: auth.UserAuth{UserId: "not-of-account"}, expectedStatus: http.StatusForbidden, }, { name: "blocked user", - requestAuth: nbcontext.UserAuth{UserId: "blocked-user"}, + requestAuth: auth.UserAuth{UserId: "blocked-user"}, expectedStatus: http.StatusForbidden, }, { name: "service user", - requestAuth: nbcontext.UserAuth{UserId: "service-user"}, + requestAuth: auth.UserAuth{UserId: "service-user"}, expectedStatus: http.StatusForbidden, }, { name: "owner", - requestAuth: nbcontext.UserAuth{UserId: "owner"}, + requestAuth: auth.UserAuth{UserId: "owner"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "owner", @@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "regular user", - requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, + requestAuth: auth.UserAuth{UserId: "regular-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "regular-user", @@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "admin user", - requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, + requestAuth: auth.UserAuth{UserId: "admin-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "admin-user", @@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "restricted user", - requestAuth: nbcontext.UserAuth{UserId: "restricted-user"}, + requestAuth: auth.UserAuth{UserId: "restricted-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "restricted-user", @@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) { req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } @@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) { req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c31..38cf0c290 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -9,40 +9,62 @@ import ( "time" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" - "github.com/netbirdio/netbird/management/server/auth" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) -type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) -type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) +type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error -type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager + authManager serverauth.Manager ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc + rateLimiter *APIRateLimiter + patUsageTracker *PATUsageTracker } // NewAuthMiddleware instance constructor func NewAuthMiddleware( - authManager auth.Manager, + authManager serverauth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, + rateLimiterConfig *RateLimiterConfig, + meter metric.Meter, ) *AuthMiddleware { + var rateLimiter *APIRateLimiter + if rateLimiterConfig != nil { + rateLimiter = NewAPIRateLimiter(rateLimiterConfig) + } + + var patUsageTracker *PATUsageTracker + if meter != nil { + var err error + patUsageTracker, err = NewPATUsageTracker(context.Background(), meter) + if err != nil { + log.Errorf("Failed to create PAT usage tracker: %s", err) + } + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, + rateLimiter: rateLimiter, + patUsageTracker: patUsageTracker, } } @@ -53,18 +75,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return } - auth := strings.Split(r.Header.Get("Authorization"), " ") - authType := strings.ToLower(auth[0]) + authHeader := strings.Split(r.Header.Get("Authorization"), " ") + authType := strings.ToLower(authHeader[0]) // fallback to token when receive pat as bearer - if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") { + if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") { authType = "token" - auth[0] = authType + authHeader[0] = authType } switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, auth) + request, err := m.checkJWTFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) @@ -73,10 +95,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, request) case "token": - request, err := m.checkPATFromRequest(r, auth) + request, err := m.checkPATFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) + // Check if it's a status error, otherwise default to Unauthorized + if _, ok := status.FromError(err); !ok { + err = status.Errorf(status.Unauthorized, "token invalid") + } + util.WriteError(r.Context(), err, w) return } h.ServeHTTP(w, request) @@ -88,8 +114,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromJWTRequest(auth) +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { @@ -115,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h } if userAuth.AccountId != accountId { - log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) + log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) userAuth.AccountId = accountId } @@ -139,12 +165,22 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromPATRequest(auth) +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { return r, fmt.Errorf("error extracting token: %w", err) } + if m.patUsageTracker != nil { + m.patUsageTracker.IncrementUsage(token) + } + + if m.rateLimiter != nil { + if !m.rateLimiter.Allow(token) { + return r, status.Errorf(status.TooManyRequests, "too many requests") + } + } + ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { @@ -159,7 +195,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, err } - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: user.Id, AccountId: user.AccountID, Domain: accDomain, diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f5422..ba4d16796 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -12,11 +12,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -27,7 +28,9 @@ const ( domainCategory = "domainCategory" userID = "userID" tokenID = "tokenID" + tokenID2 = "tokenID2" PAT = "nbp_PAT" + PAT2 = "nbp_PAT2" JWT = "JWT" wrongToken = "wrongToken" ) @@ -49,6 +52,15 @@ var testAccount = &types.Account{ CreatedAt: time.Now().UTC(), LastUsed: util.ToPtr(time.Now().UTC()), }, + tokenID2: { + ID: tokenID2, + Name: "My second token", + HashedToken: "someHash2", + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), + CreatedBy: userID, + CreatedAt: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), + }, }, }, }, @@ -58,12 +70,15 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use if token == PAT { return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } + if token == PAT2 { + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil + } return nil, nil, "", "", fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) { if token == JWT { - return nbcontext.UserAuth{ + return nbauth.UserAuth{ UserId: userID, AccountId: accountID, Domain: testAccount.Domain, @@ -77,17 +92,17 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA Valid: true, }, nil } - return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid") + return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid") } func mockMarkPATUsed(_ context.Context, token string) error { - if token == tokenID { + if token == tokenID || token == tokenID2 { return nil } return fmt.Errorf("Should never get reached") } -func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } @@ -183,15 +198,17 @@ func TestAuthMiddleware_Handler(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -221,18 +238,290 @@ func TestAuthMiddleware_Handler(t *testing.T) { } } +func TestAuthMiddleware_RateLimiting(t *testing.T) { + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) { + // Configure rate limiter: 10 requests per minute with burst of 5 + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 10, + Burst: 5, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make burst requests - all should succeed + successCount := 0 + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 5, successCount, "All burst requests should succeed") + + // The 6th request should fail (exceeded burst) + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited") + }) + + t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) { + // Configure very low rate limit: 1 request per minute + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request should fail (rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + }) + + t.Run("Bearer Token Not Rate Limited", func(t *testing.T) { + // Configure strict rate limit + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make multiple requests with Bearer token - all should succeed + successCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)") + }) + + t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) { + // Configure rate limiter + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use first PAT token + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed") + + // Second request with same token should fail + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited") + + // Use second PAT token - should succeed because it has independent rate limit + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)") + + // Second request with PAT2 should also be rate limited + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited") + + // JWT should still work (not rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)") + }) + + t.Run("Rate Limiter Cleanup", func(t *testing.T) { + // Configure rate limiter with short cleanup interval and TTL for testing + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: 100 * time.Millisecond, + LimiterTTL: 200 * time.Millisecond, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request - should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request immediately - should fail (burst exhausted) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + + // Wait for limiter to be cleaned up (TTL + cleanup interval + buffer) + time.Sleep(400 * time.Millisecond) + + // After cleanup, the limiter should be removed and recreated with full burst capacity + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)") + + // Verify it's a fresh limiter by checking burst is reset + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") + }) +} + func TestAuthMiddleware_Handler_Child(t *testing.T) { tt := []struct { name string path string authHeader string - expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status + expectedUserAuth *nbauth.UserAuth // nil expects 401 response status }{ { name: "Valid PAT Token", path: "/test", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -244,7 +533,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid PAT Token accesses child", path: "/test?account=xyz", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -257,7 +546,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token", path: "/test", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -269,7 +558,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token with child", path: "/test?account=xyz", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -288,15 +577,17 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/pat_usage_tracker.go b/management/server/http/middleware/pat_usage_tracker.go new file mode 100644 index 000000000..331c288e7 --- /dev/null +++ b/management/server/http/middleware/pat_usage_tracker.go @@ -0,0 +1,87 @@ +package middleware + +import ( + "context" + "maps" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" +) + +// PATUsageTracker tracks PAT usage metrics +type PATUsageTracker struct { + usageCounters map[string]int64 + mu sync.Mutex + stopChan chan struct{} + ctx context.Context + histogram metric.Int64Histogram +} + +// NewPATUsageTracker creates a new PAT usage tracker with metrics +func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) { + histogram, err := meter.Int64Histogram( + "management.pat.usage_distribution", + metric.WithUnit("1"), + metric.WithDescription("Distribution of PAT token usage counts per minute"), + ) + if err != nil { + return nil, err + } + + tracker := &PATUsageTracker{ + usageCounters: make(map[string]int64), + stopChan: make(chan struct{}), + ctx: ctx, + histogram: histogram, + } + + go tracker.reportLoop() + + return tracker, nil +} + +// IncrementUsage increments the usage counter for a given token +func (t *PATUsageTracker) IncrementUsage(token string) { + t.mu.Lock() + defer t.mu.Unlock() + t.usageCounters[token]++ +} + +// reportLoop reports the usage buckets every minute +func (t *PATUsageTracker) reportLoop() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + t.reportUsageBuckets() + case <-t.stopChan: + return + } + } +} + +// reportUsageBuckets reports all token usage counts and resets counters +func (t *PATUsageTracker) reportUsageBuckets() { + t.mu.Lock() + snapshot := maps.Clone(t.usageCounters) + + clear(t.usageCounters) + t.mu.Unlock() + + totalTokens := len(snapshot) + if totalTokens > 0 { + for _, count := range snapshot { + t.histogram.Record(t.ctx, count) + } + log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens) + } +} + +// Stop stops the reporting goroutine +func (t *PATUsageTracker) Stop() { + close(t.stopChan) +} diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go new file mode 100644 index 000000000..a6266d4f3 --- /dev/null +++ b/management/server/http/middleware/rate_limiter.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiterConfig holds configuration for the API rate limiter +type RateLimiterConfig struct { + // RequestsPerMinute defines the rate at which tokens are replenished + RequestsPerMinute float64 + // Burst defines the maximum number of requests that can be made in a burst + Burst int + // CleanupInterval defines how often to clean up old limiters (how often garbage collection runs) + CleanupInterval time.Duration + // LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal) + LimiterTTL time.Duration +} + +// DefaultRateLimiterConfig returns a default configuration +func DefaultRateLimiterConfig() *RateLimiterConfig { + return &RateLimiterConfig{ + RequestsPerMinute: 100, + Burst: 120, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } +} + +// limiterEntry holds a rate limiter and its last access time +type limiterEntry struct { + limiter *rate.Limiter + lastAccess time.Time +} + +// APIRateLimiter manages rate limiting for API tokens +type APIRateLimiter struct { + config *RateLimiterConfig + limiters map[string]*limiterEntry + mu sync.RWMutex + stopChan chan struct{} +} + +// NewAPIRateLimiter creates a new API rate limiter with the given configuration +func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { + if config == nil { + config = DefaultRateLimiterConfig() + } + + rl := &APIRateLimiter{ + config: config, + limiters: make(map[string]*limiterEntry), + stopChan: make(chan struct{}), + } + + go rl.cleanupLoop() + + return rl +} + +// Allow checks if a request for the given key (token) is allowed +func (rl *APIRateLimiter) Allow(key string) bool { + limiter := rl.getLimiter(key) + return limiter.Allow() +} + +// Wait blocks until the rate limiter allows another request for the given key +// Returns an error if the context is canceled +func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + limiter := rl.getLimiter(key) + return limiter.Wait(ctx) +} + +// getLimiter retrieves or creates a rate limiter for the given key +func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.RLock() + entry, exists := rl.limiters[key] + rl.mu.RUnlock() + + if exists { + rl.mu.Lock() + entry.lastAccess = time.Now() + rl.mu.Unlock() + return entry.limiter + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + if entry, exists := rl.limiters[key]; exists { + entry.lastAccess = time.Now() + return entry.limiter + } + + requestsPerSecond := rl.config.RequestsPerMinute / 60.0 + limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst) + rl.limiters[key] = &limiterEntry{ + limiter: limiter, + lastAccess: time.Now(), + } + + return limiter +} + +// cleanupLoop periodically removes old limiters that haven't been used recently +func (rl *APIRateLimiter) cleanupLoop() { + ticker := time.NewTicker(rl.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + rl.cleanup() + case <-rl.stopChan: + return + } + } +} + +// cleanup removes limiters that haven't been used within the TTL period +func (rl *APIRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, entry := range rl.limiters { + if now.Sub(entry.lastAccess) > rl.config.LimiterTTL { + delete(rl.limiters, key) + } + } +} + +// Stop stops the cleanup goroutine +func (rl *APIRateLimiter) Stop() { + close(rl.stopChan) +} + +// Reset removes the rate limiter for a specific key +func (rl *APIRateLimiter) Reset(key string) { + rl.mu.Lock() + defer rl.mu.Unlock() + delete(rl.limiters, key) +} diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..e8513feb5 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,14 +7,22 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" + serverauth "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -22,15 +30,15 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" ) -func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) if err != nil { t.Fatalf("Failed to create test store: %v", err) @@ -42,7 +50,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create metrics: %v", err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) + peersUpdateManager := update_channel.NewPeersUpdateManager(nil) updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) done := make(chan struct{}) if validateUpdate { @@ -62,14 +70,18 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + + ctx := context.Background() + requestBuffer := server.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) } // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, MarkPATUsedFunc: authManager.MarkPATUsed, @@ -82,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee groupsManagerMock := groups.NewManagerMock() peersManager := peers.NewManager(store, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -90,7 +102,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee return apiHandler, am, done } -func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -100,7 +112,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server } } -func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) { t.Helper() select { @@ -114,8 +126,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.Up } } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} +func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) { + userAuth := auth.UserAuth{} switch token { case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 66c16870b..bc352f117 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -26,9 +26,11 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := io.ReadAll(req.Body) - if err == nil { - c.reqBody = string(body) + if req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } } return &http.Response{ StatusCode: c.code, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 51f99b3b7..f06e57196 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr APIToken: config.ExtraConfig["ApiToken"], } return NewJumpCloudManager(jumpcloudConfig, appMetrics) + case "pocketid": + pocketidConfig := PocketIdClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], + } + return NewPocketIdManager(pocketidConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go new file mode 100644 index 000000000..38a5cc67f --- /dev/null +++ b/management/server/idp/pocketid.go @@ -0,0 +1,384 @@ +package idp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type PocketIdManager struct { + managementEndpoint string + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +type pocketIdCustomClaimDto struct { + Key string `json:"key"` + Value string `json:"value"` +} + +type pocketIdUserDto struct { + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + Disabled bool `json:"disabled"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + ID string `json:"id"` + IsAdmin bool `json:"isAdmin"` + LastName string `json:"lastName"` + LdapID string `json:"ldapId"` + Locale string `json:"locale"` + UserGroups []pocketIdUserGroupDto `json:"userGroups"` + Username string `json:"username"` +} + +type pocketIdUserCreateDto struct { + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + IsAdmin bool `json:"isAdmin,omitempty"` + LastName string `json:"lastName,omitempty"` + Locale string `json:"locale,omitempty"` + Username string `json:"username"` +} + +type pocketIdPaginatedUserDto struct { + Data []pocketIdUserDto `json:"data"` + Pagination pocketIdPaginationDto `json:"pagination"` +} + +type pocketIdPaginationDto struct { + CurrentPage int `json:"currentPage"` + ItemsPerPage int `json:"itemsPerPage"` + TotalItems int `json:"totalItems"` + TotalPages int `json:"totalPages"` +} + +func (p *pocketIdUserDto) userData() *UserData { + return &UserData{ + Email: p.Email, + Name: p.DisplayName, + ID: p.ID, + AppMetadata: AppMetadata{}, + } +} + +type pocketIdUserGroupDto struct { + CreatedAt string `json:"createdAt"` + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + FriendlyName string `json:"friendlyName"` + ID string `json:"id"` + LdapID string `json:"ldapId"` + Name string `json:"name"` +} + +func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.ManagementEndpoint == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing") + } + + if config.APIToken == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing") + } + + credentials := &PocketIdCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &PocketIdManager{ + managementEndpoint: config.ManagementEndpoint, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { + var MethodsWithBody = []string{http.MethodPost, http.MethodPut} + if !slices.Contains(MethodsWithBody, method) && body != "" { + return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + } + + reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) + if query != nil { + reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode()) + } + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body)) + } else { + req, err = http.NewRequestWithContext(ctx, method, reqURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Add("X-API-KEY", p.apiToken) + + if body != "" { + req.Header.Add("content-type", "application/json") + req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength)) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// getAllUsersPaginated fetches all users from PocketID API using pagination +func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) { + var allUsers []pocketIdUserDto + currentPage := 1 + + for { + params := url.Values{} + // Copy existing search parameters + for key, values := range searchParams { + params[key] = values + } + + params.Set("pagination[limit]", "100") + params.Set("pagination[page]", fmt.Sprintf("%d", currentPage)) + + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + var profiles pocketIdPaginatedUserDto + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + allUsers = append(allUsers, profiles.Data...) + + // Check if we've reached the last page + if currentPage >= profiles.Pagination.TotalPages { + break + } + + currentPage++ + } + + return allUsers, nil +} + +func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { + body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var user pocketIdUserDto + err = p.helper.Unmarshal(body, &user) + if err != nil { + return nil, err + } + + userData := user.userData() + userData.AppMetadata = appMetadata + + return userData, nil +} + +func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, profile := range allUsers { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountId + + users = append(users, userData) + } + return users, nil +} + +func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range allUsers { + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.Split(name, " ") + + createUser := pocketIdUserCreateDto{ + Disabled: false, + DisplayName: name, + Email: email, + FirstName: firstLast[0], + LastName: firstLast[1], + Username: firstLast[0] + "." + firstLast[1], + } + payload, err := p.helper.Marshal(createUser) + if err != nil { + return nil, err + } + + body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload)) + if err != nil { + return nil, err + } + var newUser pocketIdUserDto + err = p.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountCreateUser() + } + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil +} + +func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + params := url.Values{ + // This value a + "search": []string{email}, + } + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + var profiles struct{ data []pocketIdUserDto } + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.data { + users = append(users, profile.userData()) + } + return users, nil +} + +func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "") + if err != nil { + return err + } + return nil +} + +func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "") + if err != nil { + return err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +var _ Manager = (*PocketIdManager)(nil) + +type PocketIdClientConfig struct { + APIToken string + ManagementEndpoint string +} + +type PocketIdCredentials struct { + clientConfig PocketIdClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +var _ ManagerCredentials = (*PocketIdCredentials)(nil) + +func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) { + return JWTToken{}, nil +} diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go new file mode 100644 index 000000000..126a76919 --- /dev/null +++ b/management/server/idp/pocketid_test.go @@ -0,0 +1,137 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func TestNewPocketIdManager(t *testing.T) { + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } +} + +func TestPocketID_GetUserDataByID(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) +} + +func TestPocketID_GetAccount_WithPagination(t *testing.T) { + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) +} + +func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) +} + +func TestPocketID_CreateUser(t *testing.T) { + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) +} + +func TestPocketID_InviteAndDeleteUser(t *testing.T) { + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) + + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) +} diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..69ea668ad 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -117,7 +127,7 @@ type MockIntegratedValidator struct { ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) } -func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error { return nil } @@ -136,7 +146,11 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } -func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index ce632d567..326fbfaf0 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,18 +3,19 @@ package integrated_validator import ( "context" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { - ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) - PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index a34d2086b..42f192c0a 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -22,11 +22,16 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -321,99 +326,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ return loginResp, nil } -func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { - testingServerKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testingClientKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testCases := []struct { - name string - inputFlow *config.DeviceAuthorizationFlow - expectedFlow *mgmtProto.DeviceAuthorizationFlow - expectedErrFunc require.ErrorAssertionFunc - expectedErrMSG string - expectedComparisonFunc require.ComparisonAssertionFunc - expectedComparisonMSG string - }{ - { - name: "Testing No Device Flow Config", - inputFlow: nil, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Invalid Device Flow Provider Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "NoNe", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Full Device Flow Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "hosted", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ - Provider: 0, - ProviderConfig: &mgmtProto.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.NoError, - expectedErrMSG: "should not return error", - expectedComparisonFunc: require.Equal, - expectedComparisonMSG: "should match", - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - mgmtServer := &GRPCServer{ - wgKey: testingServerKey, - config: &config.Config{ - DeviceAuthorizationFlow: testCase.inputFlow, - }, - } - - message := &mgmtProto.DeviceAuthorizationFlowRequest{} - - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) - require.NoError(t, err, "should be able to encrypt message") - - resp, err := mgmtServer.GetDeviceAuthorizationFlow( - context.TODO(), - &mgmtProto.EncryptedMessage{ - WgPubKey: testingClientKey.PublicKey().String(), - Body: encryptedMSG, - }, - ) - testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) - if testCase.expectedComparisonFunc != nil { - flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) - require.NoError(t, err, "should be able to decrypt") - - testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) - testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) - } - }) - } -} - func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") @@ -427,7 +339,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config t.Fatal(err) } - peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck @@ -451,7 +362,12 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() - accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) + accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { @@ -459,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } - ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } @@ -764,9 +683,38 @@ func Test_LoginPerformance(t *testing.T) { peerLogin := types.PeerLogin{ WireGuardPubKey: key.String(), SSHKey: "random", - Meta: extractPeerMeta(context.Background(), meta), - SetupKey: setupKey.Key, - ConnectionIP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: meta.GetHostname(), + GoOS: meta.GetGoOS(), + Kernel: meta.GetKernel(), + Platform: meta.GetPlatform(), + OS: meta.GetOS(), + OSVersion: meta.GetOSVersion(), + WtVersion: meta.GetNetbirdVersion(), + UIVersion: meta.GetUiVersion(), + KernelVersion: meta.GetKernelVersion(), + SystemSerialNumber: meta.GetSysSerialNumber(), + SystemProductName: meta.GetSysProductName(), + SystemManufacturer: meta.GetSysManufacturer(), + Environment: nbpeer.Environment{ + Cloud: meta.GetEnvironment().GetCloud(), + Platform: meta.GetEnvironment().GetPlatform(), + }, + Flags: nbpeer.Flags{ + RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(), + RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(), + ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(), + DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(), + DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(), + DisableDNS: meta.GetFlags().GetDisableDNS(), + DisableFirewall: meta.GetFlags().GetDisableFirewall(), + BlockLANAccess: meta.GetFlags().GetBlockLANAccess(), + BlockInbound: meta.GetFlags().GetBlockInbound(), + LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(), + }, + }, + SetupKey: setupKey.Key, + ConnectionIP: net.IP{1, 1, 1, 1}, } login := func() error { diff --git a/management/server/management_test.go b/management/server/management_test.go index 1a5e47354..648201d4e 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -20,12 +20,16 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -176,7 +180,6 @@ func startServer( log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) @@ -199,13 +202,19 @@ func startServer( AnyTimes() permissionsManager := permissions.NewManager(str) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(ctx, str) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) + accountManager, err := server.BuildManager( context.Background(), + nil, str, - peersUpdateManager, + networkMapController, nil, "", - "netbird.selfhosted", eventStore, nil, false, @@ -220,18 +229,19 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer( - context.Background(), + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatalf("failed creating secrets manager: %v", err) + } + mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - peersUpdateManager, secretsManager, nil, - &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, + networkMapController, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d160e7269..928098dbe 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,6 +2,7 @@ package mock_server import ( "context" + "github.com/netbirdio/netbird/shared/auth" "net" "net/netip" "time" @@ -12,10 +13,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -34,11 +33,11 @@ type MockAccountManager struct { GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -84,7 +83,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) @@ -94,7 +93,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error @@ -119,15 +118,16 @@ type MockAccountManager struct { GetStoreFunc func() store.Store UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -177,11 +177,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface @@ -469,7 +469,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { +func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { if am.GetUserFromUserAuthFunc != nil { return am.GetUserFromUserAuthFunc(ctx, userAuth) } @@ -674,7 +674,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if am.GetAccountIDFromUserAuthFunc != nil { return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) } @@ -746,11 +746,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, accountID) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface @@ -936,7 +936,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented") } -func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented") } @@ -968,21 +968,23 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } -func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { if am.GetCurrentUserInfoFunc != nil { return am.GetCurrentUserInfoFunc(ctx, userAuth) } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } -// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface -func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { - // Mock implementation - does nothing -} - func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) } return true } + +func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error { + if am.RecalculateNetworkMapCacheFunc != nil { + return am.RecalculateNetworkMapCacheFunc(ctx, accountID) + } + return nil +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 6c985410c..e3dd8b0b8 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,6 +11,11 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -785,7 +790,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + + return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { @@ -975,7 +986,7 @@ func TestValidateDomain(t *testing.T) { } func TestNameServerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup @@ -994,9 +1005,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a nameserver group with a distribution group no peers should not update account peers diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index c6cec6f7e..e2dea2c6b 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -10,8 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7874be858..6b8cf9412 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -8,11 +8,11 @@ import ( "github.com/rs/xid" - nbDomain "github.com/netbirdio/netbird/shared/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 8054d05c6..6be90baa7 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 72b15fd9a..e90c61a97 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -5,8 +5,8 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) type NetworkRouter struct { diff --git a/management/server/peer.go b/management/server/peer.go index 4cf5d1e46..7c48a8052 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -8,8 +8,6 @@ import ( "net" "slices" "strings" - "sync" - "sync/atomic" "time" "github.com/rs/xid" @@ -23,7 +21,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -31,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -95,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap) + aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers()) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -106,11 +102,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var settings *types.Settings var expired bool @@ -145,9 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.BufferUpdateAccountPeers(ctx, accountID) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil @@ -180,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio } } - log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) + log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) if err != nil { @@ -203,7 +195,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var peer *nbpeer.Peer var settings *types.Settings var peerGroupList []string - var requiresPeerUpdates bool var peerLabelChanged bool var sshChanged bool var loginExpirationChanged bool @@ -226,9 +217,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - dnsDomain = am.GetDNSDomain(settings) + dnsDomain = am.networkMapController.GetDNSDomain(settings) - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) + update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -321,10 +312,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - if peerLabelChanged || requiresPeerUpdates { - am.UpdateAccountPeers(ctx, accountID) - } else if sshChanged { - am.UpdateAccountPeer(ctx, accountID, peer.ID) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, fmt.Errorf("notify network map controller of peer update: %w", err) } return peer, nil @@ -350,7 +340,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer - var updateAccountPeers bool var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -363,11 +352,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) - if err != nil { - return err - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -387,8 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers && userID != activity.SystemInitiator { - am.BufferUpdateAccountPeers(ctx, accountID) + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err) } return nil @@ -396,41 +380,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return nil, err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - groups := make(map[string][]string) - for groupID, group := range account.Groups { - groups[groupID] = group.Peers - } - - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, err - } - - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return networkMap, nil + return am.networkMapController.GetNetworkMap(ctx, peerID) } // GetPeerNetwork returns the Network for a given peer @@ -584,7 +534,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -634,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } - if temporary { - // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually - am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) - } - if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -684,28 +629,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) - if err != nil { - updateAccountPeers = true - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) + opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if updateAccountPeers { - am.BufferUpdateAccountPeers(ctx, accountID) + if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } - return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer) + return p, nmap, pc, err } func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { @@ -720,12 +661,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start)) - }() - +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -735,7 +671,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -785,14 +721,17 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil }) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - am.BufferUpdateAccountPeers(ctx, accountID) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) + } } - return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) + return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -914,10 +853,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - am.BufferUpdateAccountPeers(ctx, accountID) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + return p, nmap, pc, err } // getPeerPostureChecks returns the posture checks for the peer. @@ -1009,57 +952,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start)) - }() - - if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, nil, nil, err - } - - emptyMap := &types.NetworkMap{ - Network: network.Copy(), - } - return peer, emptyMap, nil, nil - } - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, nil, nil, err - } - - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, nil, nil, err - } - - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return peer, networkMap, postureChecks, nil -} - func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { @@ -1083,7 +975,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return fmt.Errorf("failed to get account settings: %w", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings))) return nil } @@ -1165,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun } for _, p := range userPeers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap) + aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers()) for _, aclPeer := range aclPeers { if aclPeer.ID == peer.ID { return peer, nil @@ -1179,209 +1071,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return - } - - globalStart := time.Now() - - hasPeersConnected := false - for _, peer := range account.Peers { - if am.peersUpdateManager.HasChannel(peer.ID) { - hasPeersConnected = true - break - } - - } - - if !hasPeersConnected { - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) - return - } - - var wg sync.WaitGroup - semaphore := make(chan struct{}, 10) - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - for _, peer := range account.Peers { - if !am.peersUpdateManager.HasChannel(peer.ID) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) - continue - } - - wg.Add(1) - semaphore <- struct{}{} - go func(p *nbpeer.Peer) { - defer wg.Done() - defer func() { <-semaphore }() - - start := time.Now() - - postureChecks, err := am.getPeerPostureChecks(account, p.ID) - if err != nil { - log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) - return - } - - am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) - start = time.Now() - - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - - am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) - start = time.Now() - - proxyNetworkMap, ok := proxyNetworkMaps[p.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - - peerGroups := account.GetPeerGroups(p.ID) - start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) - am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) - }(peer) - } - - // - - wg.Wait() - if am.metrics != nil { - am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart)) - } -} - -type bufferUpdate struct { - mu sync.Mutex - next *time.Timer - update atomic.Bool + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) - - bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) - b := bufUpd.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - am.UpdateAccountPeers(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - if b.next == nil { - b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { - am.UpdateAccountPeers(ctx, accountID) - }) - return - } - b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) - }() + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) } // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { - if !am.peersUpdateManager.HasChannel(peerId) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) - return - } - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err) - return - } - - peer := account.GetPeer(peerId) - if peer == nil { - log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) - return - } - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - postureChecks, err := am.getPeerPostureChecks(account, peerId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) - return - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - - extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) - return - } - - peerGroups := account.GetPeerGroups(peerId) - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + _ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1527,16 +1227,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) - if err != nil { - return false, err - } - return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction -} - // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { @@ -1546,14 +1236,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err != nil { return nil, err } - dnsDomain := am.GetDNSDomain(settings) - - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, err - } - - dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { @@ -1587,25 +1270,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - NetworkMap: &types.NetworkMap{}, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) @@ -1614,14 +1278,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return peerDeletedEvents, nil } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} - } - return labelMap -} - // validatePeerDelete checks if the peer can be deleted. func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error { linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 42b3244ae..752563299 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,7 +13,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "testing" "time" @@ -25,10 +24,16 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" @@ -168,7 +173,16 @@ func TestPeer_SessionExpired(t *testing.T) { } func TestAccountManager_GetNetworkMap(t *testing.T) { - manager, err := createManager(t) + testGetNetworkMapGeneral(t) +} + +func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testGetNetworkMapGeneral(t) +} + +func testGetNetworkMapGeneral(t *testing.T) { + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -240,7 +254,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { // TODO: disable until we start use policy again t.Skip() - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -417,7 +431,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } func TestAccountManager_GetPeerNetwork(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -478,7 +492,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { } func TestDefaultAccountManager_GetPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -665,7 +679,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -733,12 +747,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } } -func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) { +func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) { b.Helper() - manager, err := createManager(b) + manager, updateManager, err := createManager(b) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } accountID := "test_account" @@ -789,7 +803,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou ips := account.GetTakenIPs() peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } peerKey, _ := wgtypes.GeneratePrivateKey() @@ -895,10 +909,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou err = manager.Store.SaveAccount(context.Background(), account) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } - return manager, accountID, regularUser, nil + return manager, updateManager, accountID, regularUser, nil } func BenchmarkGetPeers(b *testing.B) { @@ -919,7 +933,7 @@ func BenchmarkGetPeers(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -959,7 +973,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -971,14 +985,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) - for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels - b.ResetTimer() start := time.Now() @@ -1003,7 +1013,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } +func TestUpdateAccountPeers_Experimental(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") + testUpdateAccountPeers(t) +} + func TestUpdateAccountPeers(t *testing.T) { + testUpdateAccountPeers(t) +} + +func testUpdateAccountPeers(t *testing.T) { testCases := []struct { name string peers int @@ -1019,7 +1038,7 @@ func TestUpdateAccountPeers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) if err != nil { t.Fatalf("Failed to setup test account manager: %v", err) } @@ -1031,20 +1050,19 @@ func TestUpdateAccountPeers(t *testing.T) { t.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]chan *network_map.UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels manager.UpdateAccountPeers(ctx, account.Id) for _, channel := range peerChannels { update := <-channel assert.Nil(t, update.Update.NetbirdConfig) - assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) - assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) + assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) + assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } }) } @@ -1079,7 +1097,7 @@ func TestToSyncResponse(t *testing.T) { DNSLabel: "peer1", SSHKey: "peer1-ssh-key", } - turnRelayToken := &Token{ + turnRelayToken := &grpc.Token{ Payload: "turn-user", Signature: "turn-pass", } @@ -1159,9 +1177,9 @@ func TestToSyncResponse(t *testing.T) { }, }, } - dnsCache := &DNSConfigCache{} + dnsCache := &cache.DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := grpc.ToSyncResponse(context.Background(), config, config.HttpConfig, config.DeviceAuthorizationFlow, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1212,6 +1230,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) // assert network map DNSConfig assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + //nolint assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort) assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) @@ -1271,7 +1290,12 @@ func Test_RegisterPeerByUser(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1351,7 +1375,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1499,7 +1528,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1548,6 +1582,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1573,7 +1608,12 @@ func Test_LoginPeer(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1706,7 +1746,7 @@ func Test_LoginPeer(t *testing.T) { } func TestPeerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) @@ -1763,13 +1803,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) { var peer5 *nbpeer.Peer var peer6 *nbpeer.Peer - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) @@ -1790,7 +1831,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) // + peerShouldReceiveUpdate(t, updMsg) // close(done) }() @@ -1815,7 +1856,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1871,6 +1912,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) { }) t.Run("validator requires no update", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") + requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { return update, false, nil } @@ -2072,7 +2115,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { } func Test_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2169,7 +2212,7 @@ func Test_IsUniqueConstraintError(t *testing.T) { } func Test_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2257,136 +2300,8 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, uint64(totalPeers), account.Network.Serial) } -func TestBufferUpdateAccountPeers(t *testing.T) { - const ( - peersCount = 1000 - updateAccountInterval = 50 * time.Millisecond - ) - - var ( - deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 - uapLastRun, dpLastRun atomic.Int64 - - totalNewRuns, totalOldRuns int - ) - - uap := func(ctx context.Context, accountID string) { - updatePeersDeleted.Store(deletedPeers.Load()) - updatePeersRuns.Add(1) - uapLastRun.Store(time.Now().UnixMilli()) - time.Sleep(100 * time.Millisecond) - } - - t.Run("new approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) - b := mu.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - uap(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - b.next = time.AfterFunc(updateAccountInterval, func() { - uap(ctx, accountID) - }) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalNewRuns = int(updatePeersRuns.Load()) - }) - - t.Run("old approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) - b := mu.(*sync.Mutex) - - if !b.TryLock() { - return - } - - go func() { - time.Sleep(updateAccountInterval) - b.Unlock() - uap(ctx, accountID) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalOldRuns = int(updatePeersRuns.Load()) - }) - assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) - t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) -} - func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2423,7 +2338,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2457,7 +2372,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { } func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2522,7 +2437,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go deleted file mode 100644 index cb135f4ac..000000000 --- a/management/server/peers/manager.go +++ /dev/null @@ -1,68 +0,0 @@ -package peers - -//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" -) - -type Manager interface { - GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) - GetPeerAccountID(ctx context.Context, peerID string) (string, error) - GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) - GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) -} - -type managerImpl struct { - store store.Store - permissionsManager permissions.Manager -} - -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { - return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - } -} - -func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) -} - -func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - } - - return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") -} - -func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { - return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) -} - -func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { - return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) -} diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/policy.go b/management/server/policy.go index 9e4b3f73a..3e84c3d10 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" @@ -252,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin return validIDs } - -// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - fwRule := &proto.FirewallRule{ - PolicyID: []byte(rule.PolicyID), - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, - } - - if shouldUsePortRange(fwRule) { - fwRule.PortInfo = rule.PortRange.ToProto() - } - - result[i] = fwRule - } - return result -} - -func shouldUsePortRange(rule *proto.FirewallRule) bool { - return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) -} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 4a08f4c33..a3f987732 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers()) assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 8) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", @@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) { PolicyID: "RuleDefault", }, { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", @@ -413,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }) t.Run("check port ranges support for older peers", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 1) assert.Contains(t, peers, account.Peers["peerI"]) @@ -539,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerC"]) expectedFirewallRules := []*types.FirewallRule{ @@ -569,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerB"]) expectedFirewallRules := []*types.FirewallRule{ @@ -601,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerC"]) expectedFirewallRules := []*types.FirewallRule{ @@ -623,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerB"]) expectedFirewallRules := []*types.FirewallRule{ @@ -821,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -831,12 +927,60 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) - assert.Len(t, firewallRules, 1) + assert.Len(t, firewallRules, 7) expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.80.39", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.21.56", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", @@ -848,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -858,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -873,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -900,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) @@ -991,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { } func TestPolicyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -1020,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var policyWithGroupRulesNoPeers *types.Policy diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index d65dc5045..f0bbbc32e 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,8 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go index 411f4c2c6..2ef97a066 100644 --- a/management/server/posture/os_version.go +++ b/management/server/posture/os_version.go @@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error { func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { if check == nil { - log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { if check == nil { - log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS) return false, nil } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 943f2a970..9a743eb8c 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,19 +2,15 @@ package server import ( "context" - "errors" - "fmt" "slices" "github.com/rs/xid" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -136,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } -// getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { - peerPostureChecks := make(map[string]*posture.Checks) - - if len(account.PostureChecks) == 0 { - return nil, nil - } - - for _, policy := range account.Policies { - if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { - continue - } - - if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { - return nil, err - } - } - - return maps.Values(peerPostureChecks), nil -} - // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) @@ -183,7 +158,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St // validatePostureChecks validates the posture checks. func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. @@ -211,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account return nil } -// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) - if err != nil { - return err - } - - if !isInGroup { - return nil - } - - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.GetPostureChecks(sourcePostureCheckID) - if postureCheck == nil { - return errors.New("failed to add policy posture checks: posture checks not found") - } - peerPostureChecks[sourcePostureCheckID] = postureCheck - } - - return nil -} - -// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, sourceGroup := range rule.Sources { - group := account.GetGroup(sourceGroup) - if group == nil { - return false, fmt.Errorf("failed to check peer in policy source group: group not found") - } - - if slices.Contains(group.Peers, peerID) { - return true, nil - } - } - } - - return false, nil -} - // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 67760d55a..13152ed12 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -21,7 +21,7 @@ const ( ) func TestDefaultAccountManager_PostureCheck(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er } func TestPostureCheckAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) postureCheckA := &posture.Checks{ @@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy where destination has peers but source does not // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + updateManager.CloseChannel(context.Background(), peer2.ID) }) _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ @@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } func TestArePostureCheckChangesAffectPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestPostureChecksAccount(manager) diff --git a/management/server/route.go b/management/server/route.go index 4510426bb..2b4f11d05 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -16,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -372,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID return groupsMap, nil } -func toProtocolRoute(route *route.Route) *proto.Route { - return &proto.Route{ - ID: string(route.ID), - NetID: string(route.NetID), - Network: route.Network.String(), - Domains: route.Domains.ToPunycodeList(), - NetworkType: int64(route.NetworkType), - Peer: route.Peer, - Metric: int64(route.Metric), - Masquerade: route.Masquerade, - KeepRoute: route.KeepRoute, - SkipAutoApply: route.SkipAutoApply, - } -} - -func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0, len(routes)) - for _, r := range routes { - protoRoutes = append(protoRoutes, toProtocolRoute(r)) - } - return protoRoutes -} - // getPlaceholderIP returns a placeholder IP address for the route if domains are used func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { - result := make([]*proto.RouteFirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - result[i] = &proto.RouteFirewallRule{ - SourceRanges: rule.SourceRanges, - Action: getProtoAction(rule.Action), - Destination: rule.Destination, - Protocol: getProtoProtocol(rule.Protocol), - PortInfo: getProtoPortInfo(rule), - IsDynamic: rule.IsDynamic, - Domains: rule.Domains.ToPunycodeList(), - PolicyID: []byte(rule.PolicyID), - RouteID: string(rule.RouteID), - } - } - - return result -} - -// getProtoDirection converts the direction to proto.RuleDirection. -func getProtoDirection(direction int) proto.RuleDirection { - if direction == types.FirewallRuleDirectionOUT { - return proto.RuleDirection_OUT - } - return proto.RuleDirection_IN -} - -// getProtoAction converts the action to proto.RuleAction. -func getProtoAction(action string) proto.RuleAction { - if action == string(types.PolicyTrafficActionDrop) { - return proto.RuleAction_DROP - } - return proto.RuleAction_ACCEPT -} - -// getProtoProtocol converts the protocol to proto.RuleProtocol. -func getProtoProtocol(protocol string) proto.RuleProtocol { - switch types.PolicyRuleProtocolType(protocol) { - case types.PolicyRuleProtocolALL: - return proto.RuleProtocol_ALL - case types.PolicyRuleProtocolTCP: - return proto.RuleProtocol_TCP - case types.PolicyRuleProtocolUDP: - return proto.RuleProtocol_UDP - case types.PolicyRuleProtocolICMP: - return proto.RuleProtocol_ICMP - default: - return proto.RuleProtocol_UNKNOWN - } -} - -// getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { - var portInfo proto.PortInfo - if rule.Port != 0 { - portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} - } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { - portInfo.PortSelection = &proto.PortInfo_Range_{ - Range: &proto.PortInfo_Range{ - Start: uint32(portRange.Start), - End: uint32(portRange.End), - }, - } - } - return &portInfo -} - // areRouteChangesAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers. func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { diff --git a/management/server/route_test.go b/management/server/route_test.go index 388db140c..a413d545b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -14,6 +14,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -432,7 +437,7 @@ func TestCreateRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -922,7 +927,7 @@ func TestSaveRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1024,7 +1029,7 @@ func TestDeleteRoute(t *testing.T) { Enabled: true, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1071,7 +1076,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1163,7 +1168,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1250,11 +1255,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } -func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { +func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createRouterStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} @@ -1285,7 +1290,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + + am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + if err != nil { + return nil, nil, err + } + return am, updateManager, nil } func createRouterStore(t *testing.T) (store.Store, error) { @@ -1948,7 +1962,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi } func TestRouteAccountPeersUpdate(t *testing.T) { - manager, err := createRouterManager(t) + manager, updateManager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestRouteAccount(t, manager) @@ -1976,9 +1990,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) { require.NoError(t, err, "failed to create group %s", group.Name) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + updateManager.CloseChannel(context.Background(), peer1ID) }) // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index e55b33c94..bc361bbd7 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -18,7 +18,7 @@ import ( ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } func TestGetSetupKeys(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) { } func TestSetupKeyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var setupKey *types.SetupKey @@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 382d026c8..d2220d4b4 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -24,7 +27,6 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" - nbcontext "github.com/netbirdio/netbird/management/server/context" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -46,6 +48,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 1 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -55,6 +62,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -76,12 +84,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met conns = runtime.NumCPU() } - switch storeEngine { - case types.MysqlStoreEngine: - if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil { - return nil, err - } - case types.SqliteStoreEngine: + if storeEngine == types.SqliteStoreEngine { if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } @@ -89,8 +92,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met } sql.SetMaxOpenConns(conns) + sql.SetMaxIdleConns(conns) + sql.SetConnMaxLifetime(time.Hour) + sql.SetConnMaxIdleTime(3 * time.Minute) - log.WithContext(ctx).Infof("Set max open db connections to %d", conns) + log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) if skipMigration { log.WithContext(ctx).Infof("skipping migration") @@ -162,7 +169,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro group.StoreGroupPeers() } - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -257,7 +264,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -280,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) + log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds()) return err } @@ -307,7 +314,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -405,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW return nil } +// ApproveAccountPeers marks all peers that currently require approval in the given account as approved. +func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) { + result := s.db.Model(&nbpeer.Peer{}). + Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true). + Update("peer_status_requires_approval", false) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error) + } + + return int(result.RowsAffected), nil +} + // SaveUsers saves the given list of users to the database. func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { if len(users) == 0 { @@ -575,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var user types.User - result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) + result := tx.Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -596,7 +612,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error { - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) if result.Error != nil { return result.Error @@ -774,6 +790,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -784,9 +807,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -796,70 +829,1154 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil - account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + account.InitOnce() + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil + return account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups sql.NullString + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange sql.NullString + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups sql.NullString + networkNet sql.NullString + dnsSettingsDisabledGroups sql.NullString + networkIdentifier sql.NullString + networkDns sql.NullString + networkSerial sql.NullInt64 + createdAt sql.NullTime + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(err) + } + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if networkNet.Valid { + _ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net) + } + if createdAt.Valid { + account.CreatedAt = createdAt.Time + } + if dnsSettingsDisabledGroups.Valid { + _ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups) + } + if networkIdentifier.Valid { + account.Network.Identifier = networkIdentifier.String + } + if networkDns.Valid { + account.Network.Dns = networkDns.String + } + if networkSerial.Valid { + account.Network.Serial = uint64(networkSerial.Int64) + } + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups.Valid { + _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) + } + if sNetworkRange.Valid { + _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups.Valid { + _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) + } + account.InitOnce() return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, + revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt, + &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if skCreatedAt.Valid { + sk.CreatedAt = skCreatedAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, + meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ( + lastLogin, createdAt sql.NullTime + sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + peerStatusLastSeen sql.NullTime + peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + ip, extraDNS, netAddr, env, flags, files, connIP []byte + metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString + metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString + metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString + locationCountryCode, locationCityName sql.NullString + locationGeoNameID sql.NullInt64 + ) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, + &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, + &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, + &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, + &locationCountryCode, &locationCityName, &locationGeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + p.CreatedAt = createdAt.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if metaHostname.Valid { + p.Meta.Hostname = metaHostname.String + } + if metaGoOS.Valid { + p.Meta.GoOS = metaGoOS.String + } + if metaKernel.Valid { + p.Meta.Kernel = metaKernel.String + } + if metaCore.Valid { + p.Meta.Core = metaCore.String + } + if metaPlatform.Valid { + p.Meta.Platform = metaPlatform.String + } + if metaOS.Valid { + p.Meta.OS = metaOS.String + } + if metaOSVersion.Valid { + p.Meta.OSVersion = metaOSVersion.String + } + if metaWtVersion.Valid { + p.Meta.WtVersion = metaWtVersion.String + } + if metaUIVersion.Valid { + p.Meta.UIVersion = metaUIVersion.String + } + if metaKernelVersion.Valid { + p.Meta.KernelVersion = metaKernelVersion.String + } + if metaSystemSerialNumber.Valid { + p.Meta.SystemSerialNumber = metaSystemSerialNumber.String + } + if metaSystemProductName.Valid { + p.Meta.SystemProductName = metaSystemProductName.String + } + if metaSystemManufacturer.Valid { + p.Meta.SystemManufacturer = metaSystemManufacturer.String + } + if locationCountryCode.Valid { + p.Location.CountryCode = locationCountryCode.String + } + if locationCityName.Valid { + p.Location.CityName = locationCityName.String + } + if locationGeoNameID.Valid { + p.Location.GeoNameID = uint(locationGeoNameID.Int64) + } + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin, createdAt sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + u.CreatedAt = createdAt.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + var createdAt, updatedAt sql.NullTime + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &createdAt, + &updatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if createdAt.Valid { + account.Onboarding.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + account.Onboarding.UpdatedAt = updatedAt.Time + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed, createdAt sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if createdAt.Valid { + pat.CreatedAt = createdAt.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte + var enabled, bidirectional sql.NullBool + var authorizedUser sql.NullString + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + if authorizedGroups != nil { + _ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups) + } + if authorizedUser.Valid { + r.AuthorizedUser = authorizedUser.String + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -1050,16 +2167,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var accountNetwork types.AccountNetwork - if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -1069,16 +2183,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var peer nbpeer.Peer - result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey) + result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1127,11 +2238,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - var user types.User - result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) @@ -1199,8 +2307,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectToPgDb(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil } // NewMysqlStore creates a new MySQL store. @@ -1269,7 +2410,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics, false) + store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false) if err != nil { return nil, err } @@ -1289,6 +2430,50 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } +// used for tests only +func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { + db, err := gorm.Open(postgres.Open(dsn), getGormConfig()) + if err != nil { + return nil, err + } + pool, err := connectToPgDbForTests(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} + +// used for tests only +func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 5 + config.MinConns = 1 + config.MaxConnLifetime = 30 * time.Second + config.HealthCheckPeriod = 10 * time.Second + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil +} + // NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewMysqlStore(ctx, dsn, metrics, false) @@ -1312,16 +2497,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var setupKey types.SetupKey - result := tx.WithContext(ctx). + result := tx. Take(&setupKey, GetKeyQueryCondition(s), key) if result.Error != nil { @@ -1335,10 +2517,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - - result := s.db.WithContext(ctx).Model(&types.SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1358,11 +2537,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - var groupID string - _ = s.db.WithContext(ctx).Model(types.Group{}). + _ = s.db.Model(types.Group{}). Select("id"). Where("account_id = ? AND name = ?", accountID, "All"). Limit(1). @@ -1390,9 +2566,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer // AddPeerToGroup adds a peer to a group func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - peer := &types.GroupPeer{ AccountID: accountID, GroupID: groupID, @@ -1589,10 +2762,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - - if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1718,10 +2888,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - - result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -1735,6 +2902,33 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor if tx.Error != nil { return tx.Error } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + + if s.storeEngine == types.PostgresStoreEngine { + if err := tx.Exec("SET LOCAL statement_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set statement timeout: %w", err) + } + if err := tx.Exec("SET LOCAL lock_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set lock timeout: %w", err) + } + } + + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + repo := s.withTx(tx) err := operation(repo) if err != nil { @@ -1742,6 +2936,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor return err } + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to re-enable FK checks: %w", err) + } + } + err = tx.Commit().Error log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) @@ -1759,6 +2961,31 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +// transaction wraps a GORM transaction with MySQL-specific FK checks handling +// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora +func (s *SqlStore) transaction(fn func(*gorm.DB) error) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + + err := fn(tx) + + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine && err == nil { + if fkErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; fkErr != nil { + return fmt.Errorf("failed to re-enable FK checks: %w", fkErr) + } + } + + return err + }) +} + func (s *SqlStore) GetDB() *gorm.DB { return s.db } @@ -2015,7 +3242,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } @@ -2783,36 +4010,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin return groupPeers, nil } -func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID) - } - - requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID) - } - - accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) - } - - go func() { - select { - case <-ctx.Done(): - case <-grpcCtx.Done(): - log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err()) - } - }() - return ctx, cancel -} - func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { var info types.PrimaryAccountInfo result := s.db.Model(&types.Account{}). @@ -2852,7 +4049,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i Network: &types.Network{Net: ipNet}, } - result := s.db.WithContext(ctx). + result := s.db. Model(&types.Account{}). Where(idQueryCondition, accountID). Updates(&patch) diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go new file mode 100644 index 000000000..8ff04d68a --- /dev/null +++ b/management/server/store/sql_store_get_account_test.go @@ -0,0 +1,1089 @@ +package store + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/integration_reference" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads +// all fields and nested objects from the database, including deeply nested structures. +func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: []types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t *testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== + t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") + assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, "User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] + require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index d40c4664c..2e2623910 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { }) } } + +func TestSqlStore_ApproveAccountPeers(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + accountID := "test-account" + ctx := context.Background() + + account := newAccountWithId(ctx, accountID, "testuser", "example.com") + err := store.SaveAccount(ctx, account) + require.NoError(t, err) + + peers := []*nbpeer.Peer{ + { + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.netbird.cloud", + Key: "peer1-key", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer2", + AccountID: accountID, + DNSLabel: "peer2.netbird.cloud", + Key: "peer2-key", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer3", + AccountID: accountID, + DNSLabel: "peer3.netbird.cloud", + Key: "peer3-key", + IP: net.ParseIP("100.64.0.3"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: false, + LastSeen: time.Now().UTC(), + }, + }, + } + + for _, peer := range peers { + err = store.AddPeerToAccount(ctx, peer) + require.NoError(t, err) + } + + t.Run("approve all pending peers", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 2, count) + + allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range allPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID) + } + }) + + t.Run("no peers to approve", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + + t.Run("non-existent account", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, "non-existent") + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + }) +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..350a1da83 --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,951 @@ +package store + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload(clause.Associations). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 12 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { + cleanup, dsn, err := testutil.CreatePostgresTestContainer() + if err != nil { + b.Fatalf("failed to create test container: %v", err) + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDBforTest(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + err := db.Migrator().DropTable(models[i]) + if err != nil { + b.Fatalf("failed to drop table: %v", err) + } + } + + err = db.AutoMigrate(models...) + if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, cleanup, accountID +} + +func BenchmarkGetAccount(b *testing.B) { + store, cleanup, accountID := setupBenchmarkDB(b) + defer cleanup() + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("gorm opt", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountGormOpt(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, cleanup, accountID := setupBenchmarkDB(t) + defer cleanup() + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 21b660d96..0ec7949f9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -143,6 +143,7 @@ type Store interface { SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error + ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) @@ -468,6 +469,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) closeConnection := func() { cleanup() store.Close(ctx) + if store.pool != nil { + store.pool.Close() + } } return store, closeConnection, nil @@ -487,12 +491,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + if err != nil { return nil, nil, err } @@ -519,12 +529,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) } - db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) } + sqlDB, err := db.DB() + if err != nil { + return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err) + } + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(1) + dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB.Close() + if err != nil { return nil, nil, err } @@ -537,6 +557,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return store, cleanup, nil } +func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) { + var db *gorm.DB + var err error + + for i := range maxRetries { + switch engine { + case types.PostgresStoreEngine: + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case types.MysqlStoreEngine: + db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + } + + if err == nil { + return db, nil + } + + if i < maxRetries-1 { + waitTime := time.Duration(100*(i+1)) * time.Millisecond + time.Sleep(waitTime) + } + } + + return nil, err +} + func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) @@ -544,21 +589,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func( return "", nil, fmt.Errorf("failed to create database: %v", err) } - var err error + originalDSN := dsn + cleanup := func() { + var dropDB *gorm.DB + var err error + switch engine { case types.PostgresStoreEngine: - err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error + case types.MysqlStoreEngine: - // err = killMySQLConnections(dsn, dbName) - err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error } + if err != nil { log.Errorf("failed to drop database %s: %v", dbName, err) - panic(err) } - sqlDB, _ := db.DB() - _ = sqlDB.Close() } return replaceDBName(dsn, dbName), cleanup, nil diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index d4301802f..bd7fbc235 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -16,7 +16,6 @@ type GRPCMetrics struct { meter metric.Meter syncRequestsCounter metric.Int64Counter syncRequestsBlockedCounter metric.Int64Counter - syncRequestHighLatencyCounter metric.Int64Counter loginRequestsCounter metric.Int64Counter loginRequestsBlockedCounter metric.Int64Counter loginRequestHighLatencyCounter metric.Int64Counter @@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } - syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), - ) - if err != nil { - return nil, err - } - loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro meter: meter, syncRequestsCounter: syncRequestsCounter, syncRequestsBlockedCounter: syncRequestsBlockedCounter, - syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, loginRequestsCounter: loginRequestsCounter, loginRequestsBlockedCounter: loginRequestsBlockedCounter, loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, @@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration // CountSyncRequestDuration counts the duration of the sync gRPC requests func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) - if duration > HighLatencyThreshold { - grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) - } } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index ae27466d9..c50ed1e51 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -7,8 +7,8 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/gorilla/mux" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -169,7 +169,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { //nolint ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource) - reqID := uuid.New().String() + reqID := xid.New().String() //nolint ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) @@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, r.WithContext(ctx)) + userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) + if err == nil { + if userAuth.AccountId != "" { + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) + } + if userAuth.UserId != "" { + //nolint + ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) + } + } + if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } else { diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..c43e0bb57 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,6 +8,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/hashicorp/go-multierror" @@ -15,6 +16,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -39,8 +41,22 @@ const ( // firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules. firewallRuleMinPortRangesVer = "0.48.0" + // firewallRuleMinNativeSSHVer defines the minimum peer version that supports native SSH features in the firewall rules. + firewallRuleMinNativeSSHVer = "0.60.0" + + // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections. + nativeSSHPortString = "22022" + nativeSSHPortNumber = 22022 + // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections. + defaultSSHPortString = "22" + defaultSSHPortNumber = 22 ) +type supportedFeatures struct { + nativeSSH bool + portRanges bool +} + type LookupMap map[string]struct{} // AccountMeta is a struct that contains a stripped down version of the Account object. @@ -87,6 +103,13 @@ type Account struct { NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` + + NetworkMapCache *NetworkMapBuilder `gorm:"-"` + nmapInitOnce *sync.Once `gorm:"-"` +} + +func (a *Account) InitOnce() { + a.nmapInitOnce = &sync.Once{} } // this class is used by gorm only @@ -255,9 +278,9 @@ func (a *Account) GetPeerNetworkMap( resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, metrics *telemetry.AccountManagerMetrics, + groupIDToUserIDs map[string][]string, ) *NetworkMap { start := time.Now() - peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ @@ -271,7 +294,7 @@ func (a *Account) GetPeerNetworkMap( } } - aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) + aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -301,7 +324,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -319,6 +342,8 @@ func (a *Account) GetPeerNetworkMap( OfflinePeers: expiredPeers, FirewallRules: firewallRules, RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + AuthorizedUsers: authorizedUsers, + EnableSSH: enableSSH, } if metrics != nil { @@ -890,6 +915,8 @@ func (a *Account) Copy() *Account { NetworkRouters: networkRouters, NetworkResources: networkResources, Onboarding: a.Onboarding, + NetworkMapCache: a.NetworkMapCache, + nmapInitOnce: a.nmapInitOnce, } } @@ -988,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map // GetPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer) + authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs + sshEnabled := false for _, policy := range a.Policies { if !policy.Enabled { @@ -1032,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P if peerInDestinations { generateResources(rule, sourcePeers, FirewallRuleDirectionIN) } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := groupIDToUserIDs[groupID] + if !ok { + log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID) + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() + } } } - return getAccumulatedResources() + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (a *Account) getAllowedUserIDs() map[string]struct{} { + users := make(map[string]struct{}) + for _, nbUser := range a.Users { + if !nbUser.IsBlocked() && !nbUser.IsServiceUser { + users[nbUser.Id] = struct{}{} + } + } + return users } // connResourcesGenerator returns generator and accumulator function which returns the result of generator calls @@ -1049,14 +1126,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer rules := make([]*FirewallRule, 0) peers := make([]*nbpeer.Peer, 0) - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &Group{} - } - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) for _, peer := range groupPeers { if peer == nil { continue @@ -1067,16 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer peersExists[peer.ID] = struct{}{} } + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + fr := FirewallRule{ PolicyID: rule.ID, PeerIP: peer.IP.String(), Direction: direction, Action: string(rule.Action), - Protocol: string(rule.Protocol), - } - - if isAll { - fr.PeerIP = "0.0.0.0" + Protocol: string(protocol), } ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + @@ -1098,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer } } +func policyRuleImpliesLegacySSH(rule *PolicyRule) bool { + return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges))) +} + +func portRangeIncludesSSH(portRanges []RulePortRange) bool { + for _, pr := range portRanges { + if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) { + return true + } + } + return false +} + +func portsIncludesSSH(ports []string) bool { + for _, port := range ports { + if port == defaultSSHPortString || port == nativeSSHPortString { + return true + } + } + return false +} + // getAllPeersFromGroups for given peer ID and list of groups // // Returns a list of peers from specified groups that pass specified posture checks @@ -1244,6 +1337,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID } } } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := validatedPeersMap[rule.SourceResource.ID] + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) for pID := range distPeersWithPolicy { @@ -1589,6 +1689,10 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st sourcePeers[peer] = struct{}{} } } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } } } @@ -1639,24 +1743,46 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error { return nil } +func (a *Account) GetActiveGroupUsers() map[string][]string { + allGroupID := "" + group, err := a.GetGroupAll() + if err != nil { + log.Errorf("failed to get group all: %v", err) + } else { + allGroupID = group.ID + } + groups := make(map[string][]string, len(a.GroupsG)) + for _, user := range a.Users { + if !user.IsBlocked() && !user.IsServiceUser { + for _, groupID := range user.AutoGroups { + groups[groupID] = append(groups[groupID], user.Id) + } + groups[allGroupID] = append(groups[allGroupID], user.Id) + } + } + return groups +} + // expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { + features := peerSupportedFirewallFeatures(peer.Meta.WtVersion) + var expanded []*FirewallRule - if len(rule.Ports) > 0 { - for _, port := range rule.Ports { - fr := base - fr.Port = port - expanded = append(expanded, &fr) - } - return expanded + for _, port := range rule.Ports { + fr := base + fr.Port = port + expanded = append(expanded, &fr) } - supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion) for _, portRange := range rule.PortRanges { + // prefer PolicyRule.Ports + if len(rule.Ports) > 0 { + break + } fr := base - if supportPortRanges { + if features.portRanges { fr.PortRange = portRange } else { // Peer doesn't support port ranges, only allow single-port ranges @@ -1668,21 +1794,67 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer expanded = append(expanded, &fr) } + if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH { + expanded = addNativeSSHRule(base, expanded) + } + return expanded } -// peerSupportsPortRanges checks if the peer version supports port ranges. -func peerSupportsPortRanges(peerVer string) bool { - if strings.Contains(peerVer, "dev") { - return true +// addNativeSSHRule adds a native SSH rule (port 22022) to the expanded rules if the base rule has port 22 configured. +func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule { + shouldAdd := false + for _, fr := range expanded { + if isPortInRule(nativeSSHPortString, 22022, fr) { + return expanded + } + if isPortInRule(defaultSSHPortString, 22, fr) { + shouldAdd = true + } + } + if !shouldAdd { + return expanded } - meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) - return err == nil && meetMinVer + fr := base + fr.Port = nativeSSHPortString + return append(expanded, &fr) +} + +func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool { + return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End) +} + +// shouldCheckRulesForNativeSSH determines whether specific policy rules should be checked for native SSH support. +// While users can add the nativeSSHPortString, we look for cases when they used port 22 and based on SSH enabled +// in both management and client, we indicate to add the native port. +func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool { + return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP +} + +// peerSupportedFirewallFeatures checks if the peer version supports port ranges. +func peerSupportedFirewallFeatures(peerVer string) supportedFeatures { + if strings.Contains(peerVer, "dev") { + return supportedFeatures{true, true} + } + + var features supportedFeatures + + meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer) + features.nativeSSH = err == nil && meetMinVer + + if features.nativeSSH { + features.portRanges = true + } else { + meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) + features.portRanges = err == nil && meetMinVer + } + + return features } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1865,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..2c9f2428d 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -839,12 +839,466 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, sourcePeers, 2, "expected source peers don't match") } +func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + rule *PolicyRule + base FirewallRule + expectedPorts []string + }{ + { + name: "adds port 22022 when SSH enabled on modern peer with port 22", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "adds port 22022 once when port 22 is duplicated within policy", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22", "80", "22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "80", "22", "22022"}, + }, + { + name: "does not add 22022 for peer with old version", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when SSHEnabled is false", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: false, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when ServerSSHAllowed is false", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: false}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 for UDP protocol", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolUDP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "udp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when port 22 not in rule", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80", "443"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"80", "443"}, + }, + { + name: "does not duplicate 22022 when already present", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22", "22022"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "does not duplicate 22022 when already within a port range", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 32000}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-32000"}, + }, + { + name: "adds 22022 when port 22 in port range", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 25}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-25", "22022"}, + }, + { + name: "adds single 22022 once when port 22 in multiple port ranges", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 25}, {Start: 10, End: 100}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-25", "10-100", "22022"}, + }, + { + name: "dev suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.50.0-dev", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "dev suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "dev", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "development suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPortsAndRanges(tt.base, tt.rule, tt.peer) + + var ports []string + for _, fr := range result { + if fr.Port != "" { + ports = append(ports, fr.Port) + } else if fr.PortRange.Start > 0 { + ports = append(ports, fmt.Sprintf("%d-%d", fr.PortRange.Start, fr.PortRange.End)) + } + } + + assert.Equal(t, tt.expectedPorts, ports, "expanded ports should match expected") + }) + } +} + +func Test_GetActiveGroupUsers(t *testing.T) { + tests := []struct { + name string + account *Account + expected map[string][]string + }{ + { + name: "all users are active", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2", "group3"}, + Blocked: false, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user3"}, + "group2": {"user1", "user2"}, + "group3": {"user2"}, + "": {"user1", "user2", "user3"}, + }, + }, + { + name: "some users are blocked", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2", "group3"}, + Blocked: true, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1", "group3"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user3"}, + "group2": {"user1"}, + "group3": {"user3"}, + "": {"user1", "user3"}, + }, + }, + { + name: "all users are blocked", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1"}, + Blocked: true, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2"}, + Blocked: true, + }, + }, + }, + expected: map[string][]string{}, + }, + { + name: "user with no auto groups", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user2"}, + "": {"user1", "user2"}, + }, + }, + { + name: "empty account", + account: &Account{ + Users: map[string]*User{}, + }, + expected: map[string][]string{}, + }, + { + name: "multiple users in same group", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user2", "user3"}, + "": {"user1", "user2", "user3"}, + }, + }, + { + name: "user in multiple groups with blocked users", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2", "group3"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1", "group2"}, + Blocked: true, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group3"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1"}, + "group2": {"user1"}, + "group3": {"user1", "user3"}, + "": {"user1", "user3"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetActiveGroupUsers() + + // Check that the number of groups matches + assert.Equal(t, len(tt.expected), len(result), "number of groups should match") + + // Check each group's users + for groupID, expectedUsers := range tt.expected { + actualUsers, exists := result[groupID] + assert.True(t, exists, "group %s should exist in result", groupID) + assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID) + } + + // Ensure no extra groups in result + for groupID := range result { + _, exists := tt.expected[groupID] + assert.True(t, exists, "unexpected group %s in result", groupID) + } + }) + } +} + func Test_FilterZoneRecordsForPeers(t *testing.T) { tests := []struct { name string peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +1311,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +1345,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +1380,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +1391,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/management/server/types/holder.go b/management/server/types/holder.go new file mode 100644 index 000000000..3996db2b6 --- /dev/null +++ b/management/server/types/holder.go @@ -0,0 +1,43 @@ +package types + +import ( + "context" + "sync" +) + +type Holder struct { + mu sync.RWMutex + accounts map[string]*Account +} + +func NewHolder() *Holder { + return &Holder{ + accounts: make(map[string]*Account), + } +} + +func (h *Holder) GetAccount(id string) *Account { + h.mu.RLock() + defer h.mu.RUnlock() + return h.accounts[id] +} + +func (h *Holder) AddAccount(account *Account) { + h.mu.Lock() + defer h.mu.Unlock() + h.accounts[account.Id] = account +} + +func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { + h.mu.Lock() + defer h.mu.Unlock() + if acc, ok := h.accounts[id]; ok { + return acc, nil + } + account, err := accGetter(context.Background(), id) + if err != nil { + return nil, err + } + h.accounts[id] = account + return account, nil +} diff --git a/management/server/types/network.go b/management/server/types/network.go index ffc019565..d3708d80a 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -38,6 +38,8 @@ type NetworkMap struct { FirewallRules []*FirewallRule RoutesFirewallRules []*RouteFirewallRule ForwardingRules []*ForwardingRule + AuthorizedUsers map[string]map[string]struct{} + EnableSSH bool } func (nm *NetworkMap) Merge(other *NetworkMap) { diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go new file mode 100644 index 000000000..c1099726f --- /dev/null +++ b/management/server/types/networkmap.go @@ -0,0 +1,58 @@ +package types + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { + if a.NetworkMapCache != nil { + return + } + a.nmapInitOnce.Do(func() { + a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) + }) +} + +func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} + +func (a *Account) GetPeerNetworkMapExp( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + a.initNetworkMapBuilder(validatedPeers) + return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics) +} + +func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerAddedIncremental(peerId) +} + +func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerDeleted(peerId) +} + +func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { + if a.NetworkMapCache == nil { + return + } + a.NetworkMapCache.UpdatePeer(peer) +} + +func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go new file mode 100644 index 000000000..913094e4c --- /dev/null +++ b/management/server/types/networkmap_golden_test.go @@ -0,0 +1,1069 @@ +package types_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "slices" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// update flag is used to update the golden file. +// example: go test ./... -v -update +// var update = flag.Bool("update", false, "update golden files") + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. + expiredPeerID = "peer-98" // This peer will be online but with an expired session. + offlinePeerID = "peer-99" // This peer will be completely offline. + routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. + testAccountID = "account-golden-test" +) + +func TestGetPeerNetworkMap_Golden(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file") +} + +func BenchmarkGetPeerNetworkMap(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + b.ResetTimer() + b.Run("old builder", func(b *testing.B) { + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) + } + } + }) + b.ResetTimer() + b.Run("new builder", func(b *testing.B) { + for range b.N { + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + for _, peerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") + + t.Log("Update golden file with new peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newPeerID) + require.NoError(t, err, "error adding peer to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") + t.Log("Update golden file with OnPeerAdded...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newPeerID := "peer-new-101" + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: net.IP{100, 64, 1, 1}, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + } + + account.Peers[newPeerID] = newPeer + account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) + account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) + validatedPeersMap[newPeerID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") + + t.Log("Update golden file with new router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newRouterID) + require.NoError(t, err, "error adding router to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") + + t.Log("Update golden file with OnPeerAdded router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newRouterID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedPeerID := "peer-25" // peer from devs group + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedPeerID := "peer-25" // devs group peer + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedPeerID) + require.NoError(t, err, "error deleting peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") + t.Log("Update golden file with OnPeerDeleted...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedRouterID) + require.NoError(t, err, "error deleting routing peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") + + t.Log("Update golden file with deleted router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err) + + require.JSONEq(t, string(expectedJSON), string(jsonData), + "network map after deleting router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + deletedPeerID := "peer-25" + + delete(account.Peers, deletedPeerID) + account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + delete(validatedPeersMap, deletedPeerID) + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + b.ResetTimer() + b.Run("old builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) + } + } + }) + + b.ResetTimer() + b.Run("new builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerDeleted(deletedPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { + for _, peer := range networkMap.Peers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + for _, peer := range networkMap.OfflinePeers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + + sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) + sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) + sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) + + sort.Slice(networkMap.FirewallRules, func(i, j int) bool { + r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] + if r1.PeerIP != r2.PeerIP { + return r1.PeerIP < r2.PeerIP + } + if r1.Protocol != r2.Protocol { + return r1.Protocol < r2.Protocol + } + if r1.Direction != r2.Direction { + return r1.Direction < r2.Direction + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + return r1.Port < r2.Port + }) + + sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { + r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] + if r1.RouteID != r2.RouteID { + return r1.RouteID < r2.RouteID + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + if r1.Destination != r2.Destination { + return r1.Destination < r2.Destination + } + if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { + if r1.SourceRanges[0] != r2.SourceRanges[0] { + return r1.SourceRanges[0] < r2.SourceRanges[0] + } + } + return r1.Port < r2.Port + }) + + for _, ranges := range networkMap.RoutesFirewallRules { + sort.Slice(ranges.SourceRanges, func(i, j int) bool { + return ranges.SourceRanges[i] < ranges.SourceRanges[j] + }) + } +} + +func createTestAccountWithEntities() *types.Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + + } + + groups := map[string]*types.Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*types.Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*types.PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &types.Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &types.Network{ + Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*dns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go new file mode 100644 index 000000000..5790f1646 --- /dev/null +++ b/management/server/types/networkmapbuilder.go @@ -0,0 +1,2018 @@ +package types + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +const ( + allPeers = "0.0.0.0" + allWildcard = "0.0.0.0/0" + v6AllWildcard = "::/0" + fw = "fw:" + rfw = "route-fw:" +) + +type NetworkMapCache struct { + globalRoutes map[route.ID]*route.Route + globalRules map[string]*FirewallRule //ruleId + globalRouteRules map[string]*RouteFirewallRule //ruleId + globalPeers map[string]*nbpeer.Peer + + groupToPeers map[string][]string + peerToGroups map[string][]string + policyToRules map[string][]*PolicyRule //policyId + groupToPolicies map[string][]*Policy + groupToRoutes map[string][]*route.Route + peerToRoutes map[string][]*route.Route + + peerACLs map[string]*PeerACLView + peerRoutes map[string]*PeerRoutesView + peerDNS map[string]*nbdns.Config + + resourceRouters map[string]map[string]*routerTypes.NetworkRouter + resourcePolicies map[string][]*Policy + + globalResources map[string]*resourceTypes.NetworkResource // resourceId + + acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info + noACGRoutes map[route.ID]*RouteOwnerInfo + + mu sync.RWMutex +} + +type RouteOwnerInfo struct { + PeerID string + RouteID route.ID +} + +type PeerACLView struct { + ConnectedPeerIDs []string + FirewallRuleIDs []string +} + +type PeerRoutesView struct { + OwnRouteIDs []route.ID + NetworkResourceIDs []route.ID + InheritedRouteIDs []route.ID + RouteFirewallRuleIDs []string +} + +type NetworkMapBuilder struct { + account atomic.Pointer[Account] + cache *NetworkMapCache + validatedPeers map[string]struct{} +} + +func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { + builder := &NetworkMapBuilder{ + cache: &NetworkMapCache{ + globalRoutes: make(map[route.ID]*route.Route), + globalRules: make(map[string]*FirewallRule), + globalRouteRules: make(map[string]*RouteFirewallRule), + globalPeers: make(map[string]*nbpeer.Peer), + groupToPeers: make(map[string][]string), + peerToGroups: make(map[string][]string), + policyToRules: make(map[string][]*PolicyRule), + groupToPolicies: make(map[string][]*Policy), + groupToRoutes: make(map[string][]*route.Route), + peerToRoutes: make(map[string][]*route.Route), + peerACLs: make(map[string]*PeerACLView), + peerRoutes: make(map[string]*PeerRoutesView), + peerDNS: make(map[string]*nbdns.Config), + globalResources: make(map[string]*resourceTypes.NetworkResource), + acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), + noACGRoutes: make(map[route.ID]*RouteOwnerInfo), + }, + validatedPeers: make(map[string]struct{}), + } + builder.account.Store(account) + maps.Copy(builder.validatedPeers, validatedPeers) + + builder.initialBuild(account) + + return builder +} + +func (b *NetworkMapBuilder) initialBuild(account *Account) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + start := time.Now() + + b.buildGlobalIndexes(account) + + resourceRouters := account.GetResourceRoutersMap() + resourcePolicies := account.GetResourcePoliciesMap() + b.cache.resourceRouters = resourceRouters + b.cache.resourcePolicies = resourcePolicies + + for peerID := range account.Peers { + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + } + + log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) +} + +func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { + clear(b.cache.globalPeers) + clear(b.cache.groupToPeers) + clear(b.cache.peerToGroups) + clear(b.cache.policyToRules) + clear(b.cache.groupToPolicies) + clear(b.cache.globalRoutes) + clear(b.cache.globalRules) + clear(b.cache.globalRouteRules) + clear(b.cache.globalResources) + clear(b.cache.groupToRoutes) + clear(b.cache.peerToRoutes) + clear(b.cache.acgToRoutes) + clear(b.cache.noACGRoutes) + + maps.Copy(b.cache.globalPeers, account.Peers) + + for groupID, group := range account.Groups { + peersCopy := make([]string, len(group.Peers)) + copy(peersCopy, group.Peers) + b.cache.groupToPeers[groupID] = peersCopy + + for _, peerID := range group.Peers { + b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) + } + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + b.cache.policyToRules[policy.ID] = policy.Rules + + affectedGroups := make(map[string]struct{}) + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, groupID := range rule.Sources { + affectedGroups[groupID] = struct{}{} + } + for _, groupID := range rule.Destinations { + affectedGroups[groupID] = struct{}{} + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + groupId := rule.SourceResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + groupId := rule.DestinationResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) + } + } + + for groupID := range affectedGroups { + b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) + } + } + + for _, resource := range account.NetworkResources { + if !resource.Enabled { + continue + } + b.cache.globalResources[resource.ID] = resource + } + + for _, r := range account.Routes { + if !r.Enabled { + continue + } + for _, groupID := range r.PeerGroups { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } +} + +func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { + peer := account.GetPeer(peerID) + if peer == nil { + return + } + + allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers) + + isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) + + var emptyExpiredPeers []*nbpeer.Peer + finalAllPeers := b.addNetworksRoutingPeers( + networkResourcesRoutes, + peer, + allPotentialPeers, + emptyExpiredPeers, + isRouter, + sourcePeers, + ) + + view := &PeerACLView{ + ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), + FirewallRuleIDs: make([]string, 0, len(firewallRules)), + } + + for _, p := range finalAllPeers { + view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) + } + + for _, rule := range firewallRules { + ruleID := b.generateFirewallRuleID(rule) + view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) + b.cache.globalRules[ruleID] = rule + } + + b.cache.peerACLs[peerID] = view +} + +func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, + validatedPeersMap map[string]struct{}, +) ([]*nbpeer.Peer, []*FirewallRule) { + peerID := peer.ID + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + fwRules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + for _, group := range peerGroups { + policies := b.cache.groupToPolicies[group] + for _, policy := range policies { + rules := b.cache.policyToRules[policy.ID] + for _, rule := range rules { + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peerInSources = rule.SourceResource.ID == peerID + } else { + peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peerInDestinations = rule.DestinationResource.ID == peerID + } else { + peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) + } + + if !peerInSources && !peerInDestinations { + continue + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peer := account.GetPeer(rule.SourceResource.ID) + if peer != nil { + sourcePeers = []*nbpeer.Peer{peer} + } + } else { + sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peer := account.GetPeer(rule.DestinationResource.ID) + if peer != nil { + destinationPeers = []*nbpeer.Peer{peer} + } + } else { + destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) + } + + if rule.Bidirectional { + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + } + } + + return peers, fwRules +} + +func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { + for _, groupID := range groupIDs { + if _, exists := peerGroupsMap[groupID]; exists { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, + excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, +) []*nbpeer.Peer { + ctx := context.Background() + uniquePeers := make(map[string]*nbpeer.Peer) + + for _, groupID := range groupIDs { + peerIDs := b.cache.groupToPeers[groupID] + for _, peerID := range peerIDs { + if peerID == excludePeerID { + continue + } + + if _, ok := validatedPeersMap[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + if len(postureChecksIDs) > 0 { + if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { + continue + } + } + + uniquePeers[peerID] = peer + } + } + + result := make([]*nbpeer.Peer, 0, len(uniquePeers)) + for _, peer := range uniquePeers { + result = append(result, peer) + } + + return result +} + +func (b *NetworkMapBuilder) generateResourcescached( + account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, + peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, +) { + isAll := false + if allGroup, err := account.GetGroupAll(); err == nil { + isAll = (len(allGroup.Peers) - 1) == len(groupPeers) + } + + for _, peer := range groupPeers { + if peer == nil { + continue + } + if _, ok := peersExists[peer.ID]; !ok { + *peers = append(*peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = allPeers + } + + var s strings.Builder + s.WriteString(rule.ID) + s.WriteString(fr.PeerIP) + s.WriteString(strconv.Itoa(direction)) + s.WriteString(fr.Protocol) + s.WriteString(fr.Action) + s.WriteString(strings.Join(rule.Ports, ",")) + + ruleID := s.String() + + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + *rules = append(*rules, &fr) + continue + } + + *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) + } +} + +func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { + ctx := context.Background() + peerID := peer.ID + + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, resource := range b.cache.globalResources { + + networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] + resourcePolicies := b.cache.resourcePolicies[resource.ID] + if len(resourcePolicies) == 0 { + continue + } + + isRouterForThisResource := false + + if networkRoutingPeers != nil { + if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { + isRoutingPeer = true + isRouterForThisResource = true + if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + + hasAccessAsClient := false + if !isRouterForThisResource { + for _, policy := range resourcePolicies { + if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + hasAccessAsClient = true + break + } + } + } + } + + if hasAccessAsClient && networkRoutingPeers != nil { + for routerPeerID, router := range networkRoutingPeers { + if router.Enabled { + if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + } + + if isRouterForThisResource { + for _, policy := range resourcePolicies { + var peersWithAccess []*nbpeer.Peer + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peersWithAccess = []*nbpeer.Peer{peer} + } else { + peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) + } + for _, p := range peersWithAccess { + allSourcePeers[p.ID] = struct{}{} + } + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (b *NetworkMapBuilder) createNetworkResourceRoutes( + resource *resourceTypes.NetworkResource, routerPeerID string, + router *routerTypes.NetworkRouter, resourcePolicies []*Policy, +) *route.Route { + if len(resourcePolicies) > 0 { + peer := b.cache.globalPeers[routerPeerID] + if peer != nil { + return resource.ToRoute(peer, router) + } + } + return nil +} + +func (b *NetworkMapBuilder) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peer.ID) + delete(networkRoutesPeers, peer.ID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) + } + } + + return peersToConnect +} + +func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { + ctx := context.Background() + peer := account.GetPeer(peerID) + if peer == nil { + return + } + resourcePolicies := b.cache.resourcePolicies + + view := &PeerRoutesView{ + OwnRouteIDs: make([]route.ID, 0), + NetworkResourceIDs: make([]route.ID, 0), + RouteFirewallRuleIDs: make([]string, 0), + } + + enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) + for _, rt := range enabledRoutes { + if rt.PeerID != "" && rt.PeerID != peerID { + if b.cache.globalPeers[rt.PeerID] == nil { + continue + } + } + + view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + aclView := b.cache.peerACLs[peerID] + if aclView != nil { + peerRoutesMembership := make(LookupMap) + for _, r := range append(enabledRoutes, disabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(LookupMap) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, aclPeerID := range aclView.ConnectedPeerIDs { + if aclPeerID == peerID { + continue + } + activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) + groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) + haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + + for _, inheritedRoute := range haFilteredRoutes { + view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) + b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute + } + } + } + + _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) + + for _, rt := range networkResourcesRoutes { + view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) + b.updateACGIndexForPeer(peerID, allRoutes) + + routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) + for _, rule := range routeFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + + if len(networkResourcesRoutes) > 0 { + networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) + for _, rule := range networkResourceFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + } + + b.cache.peerRoutes[peerID] = view +} + +func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == peerID { + delete(b.cache.noACGRoutes, routeID) + } + } + + for _, rt := range routes { + if !rt.Enabled { + continue + } + + if len(rt.AccessControlGroups) == 0 { + b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } else { + for _, acg := range rt.AccessControlGroups { + if b.cache.acgToRoutes[acg] == nil { + b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) + } + + b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } + } + } +} + +func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peer := b.cache.globalPeers[peerID] + if peer == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + // maybe here is some mess - here we store peer key (see comment below) + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + peerGroups := b.cache.peerToGroups[peerID] + for _, groupID := range peerGroups { + groupRoutes := b.cache.groupToRoutes[groupID] + for _, r := range groupRoutes { + newPeerRoute := r.Copy() + // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes + newPeerRoute.Peer = peerID + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) + takeRoute(newPeerRoute, peerID) + } + } + for _, r := range b.cache.peerToRoutes[peerID] { + takeRoute(r.Copy(), peerID) + } + return enabledRoutes, disabledRoutes +} + +func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) + for _, route := range enabledRoutes { + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := b.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) + + rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make(map[string]*Policy) + + for _, groupID := range accessControlGroups { + candidatePolicies := b.cache.groupToPolicies[groupID] + + for _, policy := range candidatePolicies { + if _, found := routePolicies[policy.ID]; found { + continue + } + policyRules := b.cache.policyToRules[policy.ID] + for _, rule := range policyRules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies[policy.ID] = policy + break + } + } + } + } + + return maps.Values(routePolicies) +} + +func (b *NetworkMapBuilder) getRouteFirewallRules( + peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, + distributionPeers map[string]struct{}, account *Account, +) []*RouteFirewallRule { + ctx := context.Background() + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) + + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (b *NetworkMapBuilder) getRulePeers( + rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, + validatedPeersMap map[string]struct{}, account *Account, +) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + + for _, id := range rule.Sources { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := validatedPeersMap[rule.SourceResource.ID] + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := b.cache.globalPeers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { + peerGroups := b.cache.peerToGroups[peerID] + checkGroups := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + checkGroups[groupID] = struct{}{} + } + + dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) + dnsConfig := &nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) + } + + b.cache.peerDNS[peerID] = dnsConfig +} + +func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { + + enabled := true + for _, groupID := range account.DNSSettings.DisabledManagementGroups { + _, found := checkGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := checkGroups[gID] + if found { + peer := b.cache.globalPeers[peerID] + if !peerIsNameserver(peer, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { + b.account.Store(account) +} + +func (b *NetworkMapBuilder) GetPeerNetworkMap( + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + account := b.account.Load() + + peer := account.GetPeer(peerID) + if peer == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + aclView := b.cache.peerACLs[peerID] + routesView := b.cache.peerRoutes[peerID] + dnsConfig := b.cache.peerDNS[peerID] + + if aclView == nil || routesView == nil || dnsConfig == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", + account.Id, objectCount) + } + } + + return nm +} + +func (b *NetworkMapBuilder) assembleNetworkMap( + account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, + dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, +) *NetworkMap { + + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, peerID := range aclView.ConnectedPeerIDs { + if _, ok := validatedPeers[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if account.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, peer) + } else { + peersToConnect = append(peersToConnect, peer) + } + } + + var routes []*route.Route + allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) + + for _, routeID := range allRouteIDs { + if route := b.cache.globalRoutes[routeID]; route != nil { + routes = append(routes, route) + } + } + + var firewallRules []*FirewallRule + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil { + firewallRules = append(firewallRules, rule) + } + } + + var routesFirewallRules []*RouteFirewallRule + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + + finalDNSConfig := *dnsConfig + if finalDNSConfig.ServiceEnable && customZone.Domain != "" { + var zones []nbdns.CustomZone + records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: customZone.Domain, + Records: records, + }) + finalDNSConfig.CustomZones = zones + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: account.Network.Copy(), + Routes: routes, + DNSConfig: finalDNSConfig, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, + } +} + +func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { + var s strings.Builder + s.WriteString(fw) + s.WriteString(rule.PolicyID) + s.WriteRune(':') + s.WriteString(rule.PeerIP) + s.WriteRune(':') + s.WriteString(strconv.Itoa(rule.Direction)) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(rule.Port) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) + s.WriteRune('-') + s.WriteString(strconv.Itoa(int(rule.PortRange.End))) + return s.String() +} + +func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { + var s strings.Builder + s.WriteString(rfw) + s.WriteString(string(rule.RouteID)) + s.WriteRune(':') + s.WriteString(rule.Destination) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(strings.Join(rule.SourceRanges, ",")) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.Port))) + return s.String() +} + +func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(peerGroups, groupID) { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { + for _, r := range account.Routes { + if !r.Enabled { + continue + } + + if r.PeerID == peerID { + return true + } + + if peer := b.cache.globalPeers[peerID]; peer != nil { + if r.Peer == peer.Key && r.PeerID == "" { + return true + } + } + } + + routers := account.GetResourceRoutersMap() + for _, networkRouters := range routers { + if router, exists := networkRouters[peerID]; exists && router.Enabled { + return true + } + } + + return false +} + +type ViewDelta struct { + AddedPeerIDs []string + RemovedPeerIDs []string + AddedRuleIDs []string + RemovedRuleIDs []string +} + +func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error { + tt := time.Now() + account := b.account.Load() + peer := account.GetPeer(peerID) + if peer == nil { + return fmt.Errorf("peer %s not found in account", peerID) + } + + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) + + b.validatedPeers[peerID] = struct{}{} + + b.cache.globalPeers[peerID] = peer + + peerGroups := b.updateIndexesForNewPeer(account, peerID) + + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + + log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) + + b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) + + log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) + + return nil +} + +func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { + peerGroups := make([]string, 0) + + for groupID, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { + b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) + } + peerGroups = append(peerGroups, groupID) + } + } + + b.cache.peerToGroups[peerID] = peerGroups + + for _, r := range account.Routes { + if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { + continue + } + for _, groupID := range r.PeerGroups { + if !slices.Contains(b.cache.groupToRoutes[groupID], r) { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } + b.cache.globalRoutes[r.ID] = r + } + + return peerGroups +} + +func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { + updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) + + if b.isPeerRouter(account, newPeerID) { + affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) + for affectedPeerID := range affectedByRoutes { + if affectedPeerID == newPeerID { + continue + } + if _, exists := updates[affectedPeerID]; !exists { + updates[affectedPeerID] = &PeerUpdateDelta{ + PeerID: affectedPeerID, + RebuildRoutesView: true, + } + } else { + updates[affectedPeerID].RebuildRoutesView = true + } + } + } + + for affectedPeerID, delta := range updates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } +} + +func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { + affected := make(map[string]struct{}) + enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) + + for _, route := range enabledRoutes { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + + for _, peerGroupID := range route.PeerGroups { + if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + } + + for _, route := range account.Routes { + if !route.Enabled { + continue + } + + routerInPeerGroups := false + for _, peerGroupID := range route.PeerGroups { + if slices.Contains(routerGroups, peerGroupID) { + routerInPeerGroups = true + break + } + } + + if routerInPeerGroups { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + affected[peerID] = struct{}{} + } + } + } + } + } + + return affected +} + +func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { + updates := make(map[string]*PeerUpdateDelta) + ctx := context.Background() + + groupAllLn := 0 + if allGroup, err := account.GetGroupAll(); err == nil { + groupAllLn = len(allGroup.Peers) - 1 + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return updates + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID { + peerInSources = true + } else { + peerInSources = b.isPeerInGroups(rule.Sources, peerGroups) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID { + peerInDestinations = true + } else { + peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups) + } + + if peerInSources { + if len(rule.Destinations) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN) + } + } + + if peerInDestinations { + if len(rule.Sources) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) + } + } + + if rule.Bidirectional { + if peerInSources { + if len(rule.Destinations) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) + } + } + if peerInDestinations { + if len(rule.Sources) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN) + } + } + } + } + } + + b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) + + b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) + + b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) + + return updates +} + +func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( + ctx context.Context, account *Account, newPeerID string, + updates map[string]*PeerUpdateDelta, +) { + resourceRouters := b.cache.resourceRouters + + for networkID, routers := range resourceRouters { + router, isRouter := routers[newPeerID] + if !isRouter || !router.Enabled { + continue + } + + for _, resource := range b.cache.globalResources { + if resource.NetworkID != networkID { + continue + } + + policies := b.cache.resourcePolicies[resource.ID] + if len(policies) == 0 { + continue + } + + peersWithAccess := make(map[string]struct{}) + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + groupPeers := b.cache.groupToPeers[sourceGroup] + for _, peerID := range groupPeers { + if peerID == newPeerID { + continue + } + + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + peersWithAccess[peerID] = struct{}{} + } + } + } + } + + for peerID := range peersWithAccess { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + } + updates[peerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } + } +} + +func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( + newPeerID string, newPeer *nbpeer.Peer, + peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + processedPeerRoutes := make(map[string]map[route.ID]struct{}) + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == newPeerID { + continue + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + + for _, acg := range peerGroups { + routeInfos := b.cache.acgToRoutes[acg] + if routeInfos == nil { + continue + } + + for routeID, info := range routeInfos { + if info.PeerID == newPeerID { + continue + } + + if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { + if _, processed := processedRoutes[routeID]; processed { + continue + } + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + } +} + +func (b *NetworkMapBuilder) addRouteFirewallUpdate( + updates map[string]*PeerUpdateDelta, peerID string, + routeID string, sourceIP string, +) { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), + } + updates[peerID] = delta + } + + for _, existing := range delta.UpdateRouteFirewallRules { + if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { + return + } + } + + delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ + RuleID: routeID, + AddSourceIP: sourceIP, + }) +} + +func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( + ctx context.Context, account *Account, newPeerID string, + newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + for _, resource := range b.cache.globalResources { + resourcePolicies := b.cache.resourcePolicies + resourceRouters := b.cache.resourceRouters + + policies := resourcePolicies[resource.ID] + peerHasAccess := false + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + if slices.Contains(peerGroups, sourceGroup) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { + peerHasAccess = true + break + } + } + } + + if peerHasAccess { + break + } + } + + if !peerHasAccess { + continue + } + + networkRouters := resourceRouters[resource.NetworkID] + for routerPeerID, router := range networkRouters { + if !router.Enabled || routerPeerID == newPeerID { + continue + } + + delta := updates[routerPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: routerPeerID, + } + updates[routerPeerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } +} + +type PeerUpdateDelta struct { + PeerID string + AddConnectedPeer string + AddFirewallRules []*FirewallRuleDelta + AddRoutes []route.ID + UpdateRouteFirewallRules []*RouteFirewallRuleUpdate + UpdateDNS bool + RebuildRoutesView bool +} +type FirewallRuleDelta struct { + Rule *FirewallRule + RuleID string + Direction int +} + +type RouteFirewallRuleUpdate struct { + RuleID string + AddSourceIP string +} + +func (b *NetworkMapBuilder) addUpdateForPeersInGroups( + updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, + rule *PolicyRule, direction int, allGroupLn int, +) { + for _, groupID := range groupIDs { + peers := b.cache.groupToPeers[groupID] + cnt := 0 + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + cnt++ + } + all := false + if allGroupLn > 0 && cnt == allGroupLn { + all = true + } + newPeer := b.cache.globalPeers[newPeerID] + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + targetPeer := b.cache.globalPeers[peerID] + if targetPeer == nil { + continue + } + + peerIPForRule := fr.PeerIP + if all { + peerIPForRule = allPeers + } + + b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer) + } + } +} + +func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( + updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, + rule *PolicyRule, direction int, +) { + if targetPeerID == newPeerID { + return + } + + if _, ok := b.validatedPeers[targetPeerID]; !ok { + return + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return + } + + targetPeer := b.cache.globalPeers[targetPeerID] + if targetPeer == nil { + return + } + + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) +} + +func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( + updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, + rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer, +) { + delta := updates[targetPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: targetPeerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 0), + } + updates[targetPeerID] = delta + } + + baseRule.PeerIP = peerIP + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(baseRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: baseRule, + RuleID: ruleID, + Direction: direction, + }) + } +} + +func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { + if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) { + aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) + } + + for _, ruleDelta := range delta.AddFirewallRules { + b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule + + if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { + aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) + } + } + } + } + + if delta.RebuildRoutesView { + b.buildPeerRoutesView(account, peerID) + } else if len(delta.UpdateRouteFirewallRules) > 0 { + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) + } + } + + if delta.UpdateDNS { + b.buildPeerDNSView(account, peerID) + } +} + +func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { + for _, update := range updates { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + rule := b.cache.globalRouteRules[ruleID] + if rule == nil { + continue + } + + if string(rule.RouteID) == update.RuleID { + if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard { + break + } + + sourceIP := update.AddSourceIP + + if strings.Contains(sourceIP, ":") { + sourceIP += "/128" // IPv6 + } else { + sourceIP += "/32" // IPv4 + } + + if !slices.Contains(rule.SourceRanges, sourceIP) { + rule.SourceRanges = append(rule.SourceRanges, sourceIP) + } + break + } + } + } +} + +func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + account := b.account.Load() + + deletedPeer := b.cache.globalPeers[peerID] + if deletedPeer == nil { + return fmt.Errorf("peer %s not found in cache", peerID) + } + + deletedPeerKey := deletedPeer.Key + peerGroups := b.cache.peerToGroups[peerID] + peerIP := deletedPeer.IP.String() + + log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) + + delete(b.validatedPeers, peerID) + + routesToDelete := []route.ID{} + + for routeID, r := range account.Routes { + if r.Peer != deletedPeerKey && r.PeerID != peerID { + continue + } + if len(r.PeerGroups) == 0 { + routesToDelete = append(routesToDelete, routeID) + continue + } + newPeerAssigned := false + for _, groupID := range r.PeerGroups { + candidatePeerIDs := b.cache.groupToPeers[groupID] + for _, candidatePeerID := range candidatePeerIDs { + if candidatePeerID == peerID { + continue + } + if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { + r.Peer = candidatePeer.Key + r.PeerID = candidatePeerID + newPeerAssigned = true + break + } + } + if newPeerAssigned { + break + } + } + + if !newPeerAssigned { + routesToDelete = append(routesToDelete, routeID) + } + } + + for _, routeID := range routesToDelete { + delete(account.Routes, routeID) + } + + delete(b.cache.peerACLs, peerID) + delete(b.cache.peerRoutes, peerID) + delete(b.cache.peerDNS, peerID) + + delete(b.cache.globalPeers, peerID) + + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for _, groupID := range peerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { + return id == peerID + }) + } + } + delete(b.cache.peerToGroups, peerID) + + affectedPeers := make(map[string]struct{}) + + for _, r := range account.Routes { + for _, groupID := range r.Groups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + + for _, groupID := range r.PeerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + } + + for affectedPeerID := range affectedPeers { + if affectedPeerID == peerID { + continue + } + b.buildPeerRoutesView(account, affectedPeerID) + } + + peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP) + for affectedPeerID, updates := range peerDeletionUpdates { + b.applyDeletionUpdates(affectedPeerID, updates) + } + + b.cleanupUnusedRules() + + log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) + + return nil +} + +func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( + deletedPeerID string, + peerIP string, +) map[string]*PeerDeletionUpdate { + + affected := make(map[string]*PeerDeletionUpdate) + + for peerID, aclView := range b.cache.peerACLs { + if peerID == deletedPeerID { + continue + } + + if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { + continue + } + if affected[peerID] == nil { + affected[peerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + } + } + + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP { + affected[peerID].RemoveFirewallRuleIDs = append( + affected[peerID].RemoveFirewallRuleIDs, + ruleID, + ) + } + } + } + + return affected +} + +type PeerDeletionUpdate struct { + RemovePeerID string + RemoveFirewallRuleIDs []string + RemoveRouteIDs []route.ID + RemoveFromSourceRanges bool + PeerIP string +} + +func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool { + return id == updates.RemovePeerID + }) + + if len(updates.RemoveFirewallRuleIDs) > 0 { + aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID) + }) + } + } + + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + if len(updates.RemoveRouteIDs) > 0 { + routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { + return slices.Contains(updates.RemoveRouteIDs, routeID) + }) + } + + if updates.RemoveFromSourceRanges { + b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) + } + } +} + +func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { + sourceIPv4 := peerIP + "/32" + sourceIPv6 := peerIP + "/128" + + rulesToRemove := []string{} + + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { + return source == sourceIPv4 || source == sourceIPv6 || source == peerIP + }) + + if len(rule.SourceRanges) == 0 { + rulesToRemove = append(rulesToRemove, ruleID) + } + } + } + + if len(rulesToRemove) > 0 { + routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(rulesToRemove, ruleID) + }) + } +} + +func (b *NetworkMapBuilder) cleanupUnusedRules() { + usedFirewallRules := make(map[string]struct{}) + usedRouteRules := make(map[string]struct{}) + usedRoutes := make(map[route.ID]struct{}) + + for _, aclView := range b.cache.peerACLs { + for _, ruleID := range aclView.FirewallRuleIDs { + usedFirewallRules[ruleID] = struct{}{} + } + } + + for _, routesView := range b.cache.peerRoutes { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + usedRouteRules[ruleID] = struct{}{} + } + + for _, routeID := range routesView.OwnRouteIDs { + usedRoutes[routeID] = struct{}{} + } + for _, routeID := range routesView.NetworkResourceIDs { + usedRoutes[routeID] = struct{}{} + } + } + + for ruleID := range b.cache.globalRules { + if _, used := usedFirewallRules[ruleID]; !used { + delete(b.cache.globalRules, ruleID) + } + } + + for ruleID := range b.cache.globalRouteRules { + if _, used := usedRouteRules[ruleID]; !used { + delete(b.cache.globalRouteRules, ruleID) + } + } + + for routeID := range b.cache.globalRoutes { + if _, used := usedRoutes[routeID]; !used { + delete(b.cache.globalRoutes, routeID) + } + } +} + +func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + peerStored, ok := b.cache.globalPeers[peer.ID] + if !ok { + return + } + *peerStored = *peer +} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 5e86a87c6..d4e1a8816 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -23,6 +23,8 @@ const ( PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") // PolicyRuleProtocolICMP type of traffic PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") + // PolicyRuleProtocolNetbirdSSH type of traffic + PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh") ) const ( @@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) protocol = PolicyRuleProtocolUDP case "icmp": return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + case "netbird-ssh": + return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil default: return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) } diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index 2643ae45c..bb75dd555 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -80,6 +80,12 @@ type PolicyRule struct { // PortRanges a list of port ranges. PortRanges []RulePortRange `gorm:"serializer:json"` + + // AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh + AuthorizedGroups map[string][]string `gorm:"serializer:json"` + + // AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh + AuthorizedUser string } // Copy returns a copy of a policy rule @@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule { Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), PortRanges: make([]RulePortRange, len(pm.PortRanges)), + AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)), + AuthorizedUser: pm.AuthorizedUser, } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) copy(rule.PortRanges, pm.PortRanges) + for k, v := range pm.AuthorizedGroups { + rule.AuthorizedGroups[k] = make([]string, len(v)) + copy(rule.AuthorizedGroups[k], v) + } return rule } diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 6eb391cb5..da29e1d87 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -1,8 +1,8 @@ package types import ( - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // RouteFirewallRule a firewall rule applicable for a routed network. diff --git a/management/server/types/settings.go b/management/server/types/settings.go index b4afb2f5e..867e12bef 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -52,6 +52,9 @@ type Settings struct { // LazyConnectionEnabled indicates if the experimental feature is enabled or disabled LazyConnectionEnabled bool `gorm:"default:false"` + + // AutoUpdateVersion client auto-update version + AutoUpdateVersion string `gorm:"default:'disabled'"` } // Copy copies the Settings struct @@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings { LazyConnectionEnabled: s.LazyConnectionEnabled, DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, + AutoUpdateVersion: s.AutoUpdateVersion, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/user.go b/management/server/user.go index d40d33c6a..9d4620462 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -7,12 +7,13 @@ import ( "strings" "time" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" - nbContext "github.com/netbirdio/netbird/management/server/context" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -175,9 +176,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) } -// GetUser looks up a user by provided nbContext.UserAuths. +// GetUser looks up a user by provided auth.UserAuths. // Expects account to have been created already. -func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { +func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { return nil, err @@ -262,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) - } - return nil } @@ -526,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( + _, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - if userHadPeers { - updateAccountPeers = true - } + updateAccountPeers = true err = transaction.SaveUser(ctx, updatedUser) if err != nil { @@ -584,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } } - if settings.GroupsPropagationEnabled && updateAccountPeers { + if updateAccountPeers { if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } @@ -595,9 +590,15 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, isNewUser bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { var eventsToStore []func() + if isNewUser { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.UserCreated, nil) + }) + } + if oldUser.IsBlocked() != newUser.IsBlocked() { if newUser.IsBlocked() { eventsToStore = append(eventsToStore, func() { @@ -621,6 +622,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac }) } + addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err) + } + + for _, group := range addedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + }) + } + + removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err) + } + for _, group := range removedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + }) + } + return eventsToStore } @@ -631,7 +661,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) + oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) if err != nil { return false, nil, nil, nil, err } @@ -667,9 +697,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact peersToExpire = userPeers } + var removedGroups, addedGroups []string if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups) + removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups) + addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups) for _, peer := range userPeers { for _, groupID := range removedGroups { if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { @@ -685,30 +716,30 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. -func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { +func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, bool, error) { existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if !addIfNotExists { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + return nil, false, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) } update.AccountID = accountID - return update, nil // use all fields from update if addIfNotExists is true + return update, true, nil // use all fields from update if addIfNotExists is true } - return nil, err + return nil, false, err } if existingUser.AccountID != accountID { - return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch") + return nil, false, status.Errorf(status.InvalidArgument, "user account ID mismatch") } - return existingUser, nil + return existingUser, false, nil } func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) { @@ -935,12 +966,12 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if err != nil { return err } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) var peerIDs []string for _, peer := range peers { // nolint:staticcheck - ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key) if peer.UserID == "" { // we do not want to expire peers that are added via setup key @@ -963,11 +994,15 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou ) } + err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } + if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs) } return nil } @@ -1013,7 +1048,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } var allErrors error - var updateAccountPeers bool for _, targetUserID := range targetUserIDs { if initiatorUserID == targetUserID { @@ -1044,19 +1078,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { allErrors = errors.Join(allErrors, err) continue } - - if userHadPeers { - updateAccountPeers = true - } - } - - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) } return allErrors @@ -1081,6 +1107,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var addPeerRemovedEvents []func() var updateAccountPeers bool + var userPeers []*nbpeer.Peer var targetUser *types.User var err error @@ -1090,7 +1117,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return fmt.Errorf("failed to get user to delete: %w", err) } - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) + userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) } @@ -1113,6 +1140,14 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + var peerIDs []string + for _, peer := range userPeers { + peerIDs = append(peerIDs, peer.ID) + } + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil { + log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err) + } + for _, addPeerRemovedEvent := range addPeerRemovedEvents { addPeerRemovedEvent() } @@ -1171,7 +1206,7 @@ func validateUserInvite(invite *types.UserInfo) error { } // GetCurrentUserInfo retrieves the account's current user info and permissions -func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { accountID, userID := userAuth.AccountId, userAuth.UserId user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) diff --git a/management/server/user_test.go b/management/server/user_test.go index 5920a2a33..3032ee3e8 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -8,15 +8,17 @@ import ( "time" "github.com/google/go-cmp/cmp" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) { permissionsManager: permissionsManager, } - cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) require.NoError(t, err) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs) @@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - permissionsManager: permissionsManager, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, integratedPeerValidator: MockIntegratedValidator{}, permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -966,7 +983,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { permissionsManager: permissionsManager, } - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: mockUserID, AccountId: mockAccountID, } @@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { permissionsManager: permissionsManager, } - cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) assert.NoError(t, err) am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) @@ -1161,7 +1178,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { } func TestDefaultAccountManager_SaveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1333,7 +1350,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1357,16 +1374,16 @@ func TestUserAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) - // Creating a new regular user should not update account peers and not send peer update + // Creating a new regular user should send peer update (as users are not filtered yet) t.Run("creating new regular user with no groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1385,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - // updating user with no linked peers should not update account peers and not send peer update + // updating user with no linked peers should update account peers and send peer update (as users are not filtered yet) t.Run("updating user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("deleting user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1468,9 +1485,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + updateManager.CloseChannel(context.Background(), peer4.ID) }) // deleting user with linked peers should update account peers and send peer update @@ -1573,33 +1590,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { tt := []struct { name string - userAuth nbcontext.UserAuth + userAuth auth.UserAuth expectedErr error expectedResult *users.UserInfoWithPermissions }{ { name: "not found", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, { name: "not part of account", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, expectedErr: status.NewUserNotPartOfAccountError(), }, { name: "blocked", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, expectedErr: status.NewUserBlockedError(), }, { name: "service user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"}, expectedErr: status.NewPermissionDeniedError(), }, { name: "owner user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account1Owner", @@ -1619,7 +1636,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "regular user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "regular-user", @@ -1638,7 +1655,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "admin user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "admin-user", @@ -1657,7 +1674,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked regular user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1678,7 +1695,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { { name: "settings blocked regular user child account", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1698,7 +1715,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked owner user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account2Owner", @@ -1748,7 +1765,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { } func TestApproveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1807,7 +1824,7 @@ func TestApproveUser(t *testing.T) { } func TestRejectUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/relay/cmd/root.go b/relay/cmd/root.go index eb2cdebf8..e7dadcfdf 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error { log.Debugf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err) } - log.Infof("server will be available on: %s", srv.InstanceURL()) + instanceURL := srv.InstanceURL() + log.Infof("server will be available on: %s", instanceURL.String()) wg.Add(1) go func() { defer wg.Done() diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go index eedd62394..b54d4b33b 100644 --- a/relay/healthcheck/healthcheck.go +++ b/relay/healthcheck/healthcheck.go @@ -6,14 +6,14 @@ import ( "errors" "net" "net/http" + "net/url" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" - "github.com/netbirdio/netbird/relay/server/listener/quic" - "github.com/netbirdio/netbird/relay/server/listener/ws" + "github.com/netbirdio/netbird/relay/server" ) const ( @@ -27,7 +27,7 @@ const ( type ServiceChecker interface { ListenerProtocols() []protocol.Protocol - ListenAddress() string + InstanceURL() url.URL } type HealthStatus struct { @@ -135,7 +135,11 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { } status.Listeners = listeners - if ok := s.validateCertificate(ctx); !ok { + if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS { + status.CertificateValid = false + } + + if ok := s.validateConnection(ctx); !ok { status.Status = statusUnhealthy status.CertificateValid = false healthy = false @@ -152,32 +156,13 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) { return listeners, true } -func (s *Server) validateCertificate(ctx context.Context) bool { - listenAddress := s.config.ServiceChecker.ListenAddress() - if listenAddress == "" { - log.Warn("listen address is empty") +func (s *Server) validateConnection(ctx context.Context) bool { + addr := s.config.ServiceChecker.InstanceURL() + if err := dialWS(ctx, addr); err != nil { + log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err) return false } - dAddr := dialAddress(listenAddress) - - for _, proto := range s.config.ServiceChecker.ListenerProtocols() { - switch proto { - case ws.Proto: - if err := dialWS(ctx, dAddr); err != nil { - log.Errorf("failed to dial WebSocket listener: %v", err) - return false - } - case quic.Proto: - if err := dialQUIC(ctx, dAddr); err != nil { - log.Errorf("failed to dial QUIC listener: %v", err) - return false - } - default: - log.Warnf("unknown protocol for healthcheck: %s", proto) - return false - } - } return true } @@ -187,8 +172,9 @@ func dialAddress(listenAddress string) string { return listenAddress // fallback, might be invalid for dialing } + // When listening on all interfaces, show localhost for better readability if host == "" || host == "::" || host == "0.0.0.0" { - host = "0.0.0.0" + host = "localhost" } return net.JoinHostPort(host, port) diff --git a/relay/healthcheck/peerid/peerid.go b/relay/healthcheck/peerid/peerid.go new file mode 100644 index 000000000..cd8696817 --- /dev/null +++ b/relay/healthcheck/peerid/peerid.go @@ -0,0 +1,31 @@ +package peerid + +import ( + "crypto/sha256" + + v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" + "github.com/netbirdio/netbird/shared/relay/messages" +) + +var ( + // HealthCheckPeerID is the hashed peer ID for health check connections + HealthCheckPeerID = messages.HashID("healthcheck-agent") + + // DummyAuthToken is a structurally valid auth token for health check. + // The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload). + DummyAuthToken = createDummyToken() +) + +func createDummyToken() []byte { + token := v2.Token{ + AuthAlgo: v2.AuthAlgoHMACSHA256, + Signature: make([]byte, sha256.Size), + Payload: []byte("healthcheck"), + } + return token.Marshal() +} + +// IsHealthCheck checks if the given peer ID is the health check agent +func IsHealthCheck(peerID *messages.PeerID) bool { + return peerID != nil && *peerID == HealthCheckPeerID +} diff --git a/relay/healthcheck/quic.go b/relay/healthcheck/quic.go deleted file mode 100644 index 1582edf7b..000000000 --- a/relay/healthcheck/quic.go +++ /dev/null @@ -1,31 +0,0 @@ -package healthcheck - -import ( - "context" - "crypto/tls" - "fmt" - "time" - - "github.com/quic-go/quic-go" - - tlsnb "github.com/netbirdio/netbird/shared/relay/tls" -) - -func dialQUIC(ctx context.Context, address string) error { - tlsConfig := &tls.Config{ - InsecureSkipVerify: false, // Keep certificate validation enabled - NextProtos: []string{tlsnb.NBalpn}, - } - - conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{ - MaxIdleTimeout: 30 * time.Second, - KeepAlivePeriod: 10 * time.Second, - EnableDatagrams: true, - }) - if err != nil { - return fmt.Errorf("failed to connect to QUIC server: %w", err) - } - - _ = conn.CloseWithError(0, "availability check complete") - return nil -} diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index 49694356c..9267096f5 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -3,26 +3,47 @@ package healthcheck import ( "context" "fmt" + "net/url" "github.com/coder/websocket" + "github.com/netbirdio/netbird/relay/healthcheck/peerid" + "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay" + "github.com/netbirdio/netbird/shared/relay/messages" ) -func dialWS(ctx context.Context, address string) error { - url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath) +func dialWS(ctx context.Context, address url.URL) error { + scheme := "ws" + if address.Scheme == server.SchemeRELS { + scheme = "wss" + } + wsURL := fmt.Sprintf("%s://%s%s", scheme, address.Host, relay.WebSocketURLPath) - conn, resp, err := websocket.Dial(ctx, url, nil) + conn, resp, err := websocket.Dial(ctx, wsURL, nil) if resp != nil { defer func() { - _ = resp.Body.Close() + if resp.Body != nil { + _ = resp.Body.Close() + } }() } if err != nil { return fmt.Errorf("failed to connect to websocket: %w", err) } + defer func() { + _ = conn.CloseNow() + }() + + authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken) + if err != nil { + return fmt.Errorf("failed to marshal auth message: %w", err) + } + + if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil { + return fmt.Errorf("failed to write auth message: %w", err) + } - _ = conn.Close(websocket.StatusNormalClosure, "availability check complete") return nil } diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 922369798..8c3ee1899 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) { return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } if err != nil { - return nil, err + return peerID, err } h.peerID = peerID return peerID, nil @@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) { } if err := h.validator.Validate(authPayload); err != nil { - return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) + return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) } return rawPeerID, nil diff --git a/relay/server/peer.go b/relay/server/peer.go index c47f2e960..c5ff41857 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -9,10 +9,10 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/relay/healthcheck" - "github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" + "github.com/netbirdio/netbird/shared/relay/healthcheck" + "github.com/netbirdio/netbird/shared/relay/messages" ) const ( diff --git a/relay/server/relay.go b/relay/server/relay.go index d86684937..bb355f58f 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/url" "sync" "time" @@ -11,6 +12,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" + "github.com/netbirdio/netbird/relay/healthcheck/peerid" //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" @@ -22,7 +24,7 @@ type Config struct { TLSSupport bool AuthValidator Validator - instanceURL string + instanceURL url.URL } func (c *Config) validate() error { @@ -37,7 +39,7 @@ func (c *Config) validate() error { if err != nil { return fmt.Errorf("invalid url: %v", err) } - c.instanceURL = instanceURL + c.instanceURL = *instanceURL if c.AuthValidator == nil { return fmt.Errorf("auth validator is required") @@ -51,10 +53,11 @@ type Relay struct { metricsCancel context.CancelFunc validator Validator - store *store.Store - notifier *store.PeerNotifier - instanceURL string - preparedMsg *preparedMsg + store *store.Store + notifier *store.PeerNotifier + instanceURL url.URL + exposedAddress string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -87,15 +90,16 @@ func NewRelay(config Config) (*Relay, error) { } r := &Relay{ - metrics: m, - metricsCancel: metricsCancel, - validator: config.AuthValidator, - instanceURL: config.instanceURL, - store: store.NewStore(), - notifier: store.NewPeerNotifier(), + metrics: m, + metricsCancel: metricsCancel, + validator: config.AuthValidator, + instanceURL: config.instanceURL, + exposedAddress: config.ExposedAddress, + store: store.NewStore(), + notifier: store.NewPeerNotifier(), } - r.preparedMsg, err = newPreparedMsg(r.instanceURL) + r.preparedMsg, err = newPreparedMsg(r.instanceURL.String()) if err != nil { metricsCancel() return nil, fmt.Errorf("prepare message: %v", err) @@ -120,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) { } peerID, err := h.handshakeReceive() if err != nil { - log.Errorf("failed to handshake: %s", err) + if peerid.IsHealthCheck(peerID) { + log.Debugf("health check connection from %s", conn.RemoteAddr()) + } else { + log.Errorf("failed to handshake: %s", err) + } if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } @@ -175,6 +183,6 @@ func (r *Relay) Shutdown(ctx context.Context) { } // InstanceURL returns the instance URL of the relay server -func (r *Relay) InstanceURL() string { +func (r *Relay) InstanceURL() url.URL { return r.instanceURL } diff --git a/relay/server/server.go b/relay/server/server.go index 4c30e7fdc..8e4333064 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/tls" + "net/url" "sync" "github.com/hashicorp/go-multierror" @@ -28,8 +29,6 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - listenAddr string - relay *Relay listeners []listener.Listener listenerMux sync.Mutex @@ -41,7 +40,7 @@ type Server struct { // // config: A Config struct containing the necessary configuration: // - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used. -// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required. +// - InstanceURL: The public address (in domain:port format) used as the server's instance URL. Required. // - TLSSupport: A boolean indicating whether TLS is enabled for the server. // - AuthValidator: A Validator used to authenticate peers. Required. // @@ -62,8 +61,6 @@ func NewServer(config Config) (*Server, error) { // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.listenAddr = cfg.Address - wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, @@ -123,11 +120,6 @@ func (r *Server) Shutdown(ctx context.Context) error { return nberrors.FormatErrorOrNil(multiErr) } -// InstanceURL returns the instance URL of the relay server. -func (r *Server) InstanceURL() string { - return r.relay.instanceURL -} - func (r *Server) ListenerProtocols() []protocol.Protocol { result := make([]protocol.Protocol, 0) @@ -139,6 +131,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol { return result } -func (r *Server) ListenAddress() string { - return r.listenAddr +func (r *Server) InstanceURL() url.URL { + return r.relay.InstanceURL() } diff --git a/relay/server/url.go b/relay/server/url.go index 9cbf44642..aeae1c068 100644 --- a/relay/server/url.go +++ b/relay/server/url.go @@ -6,9 +6,14 @@ import ( "strings" ) +const ( + SchemeREL = "rel" + SchemeRELS = "rels" +) + // getInstanceURL checks if user supplied a URL scheme otherwise adds to the // provided address according to TLS definition and parses the address before returning it -func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { +func getInstanceURL(exposedAddress string, tlsSupported bool) (*url.URL, error) { addr := exposedAddress split := strings.Split(exposedAddress, "://") switch { @@ -17,17 +22,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { case len(split) == 1 && !tlsSupported: addr = "rel://" + exposedAddress case len(split) > 2: - return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) + return nil, fmt.Errorf("invalid exposed address: %s", exposedAddress) } parsedURL, err := url.ParseRequestURI(addr) if err != nil { - return "", fmt.Errorf("invalid exposed address: %v", err) + return nil, fmt.Errorf("invalid exposed address: %v", err) } - if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { - return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) + if parsedURL.Scheme != SchemeREL && parsedURL.Scheme != SchemeRELS { + return nil, fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) } - return parsedURL.String(), nil + // Validate scheme matches TLS configuration + if tlsSupported && parsedURL.Scheme == SchemeREL { + return nil, fmt.Errorf("non-TLS scheme '%s' provided but TLS is supported", SchemeREL) + } + + return parsedURL, nil } diff --git a/relay/server/relay_test.go b/relay/server/url_test.go similarity index 78% rename from relay/server/relay_test.go rename to relay/server/url_test.go index 062039ab9..ca455f45a 100644 --- a/relay/server/relay_test.go +++ b/relay/server/url_test.go @@ -13,7 +13,7 @@ func TestGetInstanceURL(t *testing.T) { {"Valid address with TLS", "example.com", true, "rels://example.com", false}, {"Valid address without TLS", "example.com", false, "rel://example.com", false}, {"Valid address with scheme", "rel://example.com", false, "rel://example.com", false}, - {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false}, + {"Invalid address with non TLS scheme and TLS true", "rel://example.com", true, "", true}, {"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false}, {"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false}, {"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false}, @@ -28,8 +28,11 @@ func TestGetInstanceURL(t *testing.T) { if (err != nil) != tt.expectError { t.Errorf("expected error: %v, got: %v", tt.expectError, err) } - if url != tt.expectedURL { - t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url) + if !tt.expectError && url != nil && url.String() != tt.expectedURL { + t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url.String()) + } + if tt.expectError && url != nil { + t.Errorf("expected nil URL on error, got: %s", url.String()) } }) } diff --git a/release_files/freebsd-port-diff.sh b/release_files/freebsd-port-diff.sh new file mode 100755 index 000000000..b030b9164 --- /dev/null +++ b/release_files/freebsd-port-diff.sh @@ -0,0 +1,216 @@ +#!/bin/bash +# +# FreeBSD Port Diff Generator for NetBird +# +# This script generates the diff file required for submitting a FreeBSD port update. +# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and +# computing checksums from the Go module proxy. +# +# Usage: ./freebsd-port-diff.sh [new_version] +# Example: ./freebsd-port-diff.sh 0.60.7 +# +# If no version is provided, it fetches the latest from GitHub. + +set -e + +GITHUB_REPO="netbirdio/netbird" +PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird" +GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v" +OUTPUT_DIR="${OUTPUT_DIR:-.}" +AWK_FIRST_FIELD='{print $1}' + +fetch_all_tags() { + curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ + grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ + sed 's/.*\/v//' | \ + sort -u -V + return 0 +} + +fetch_current_ports_version() { + echo "Fetching current version from FreeBSD ports..." >&2 + curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \ + grep -E "^DISTVERSION=" | \ + sed 's/DISTVERSION=[[:space:]]*//' | \ + tr -d '\t ' + return 0 +} + +fetch_latest_github_release() { + echo "Fetching latest release from GitHub..." >&2 + fetch_all_tags | tail -1 + return 0 +} + +fetch_ports_file() { + local filename="$1" + curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null + return 0 +} + +compute_checksums() { + local version="$1" + local tmpdir + tmpdir=$(mktemp -d) + # shellcheck disable=SC2064 + trap "rm -rf '$tmpdir'" EXIT + + echo "Downloading files from Go module proxy for v${version}..." >&2 + + local mod_file="${tmpdir}/v${version}.mod" + local zip_file="${tmpdir}/v${version}.zip" + + curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null + curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null + + if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then + echo "Error: Could not download files from Go module proxy" >&2 + return 1 + fi + + local mod_sha256 mod_size zip_sha256 zip_size + + if command -v sha256sum &>/dev/null; then + mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD") + zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD") + elif command -v shasum &>/dev/null; then + mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD") + zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD") + else + echo "Error: No sha256 command found" >&2 + return 1 + fi + + if [[ "$OSTYPE" == "darwin"* ]]; then + mod_size=$(stat -f%z "$mod_file") + zip_size=$(stat -f%z "$zip_file") + else + mod_size=$(stat -c%s "$mod_file") + zip_size=$(stat -c%s "$zip_file") + fi + + echo "TIMESTAMP = $(date +%s)" + echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}" + echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}" + echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}" + echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}" + return 0 +} + +generate_new_makefile() { + local new_version="$1" + local old_makefile="$2" + + # Check if old version had PORTREVISION + if echo "$old_makefile" | grep -q "^PORTREVISION="; then + # Remove PORTREVISION line and update DISTVERSION + echo "$old_makefile" | \ + sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \ + grep -v "^PORTREVISION=" + else + # Just update DISTVERSION + echo "$old_makefile" | \ + sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" + fi + return 0 +} + +# Parse arguments +NEW_VERSION="${1:-}" + +# Auto-detect versions if not provided +OLD_VERSION=$(fetch_current_ports_version) +if [[ -z "$OLD_VERSION" ]]; then + echo "Error: Could not fetch current version from FreeBSD ports" >&2 + exit 1 +fi +echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2 + +if [[ -z "$NEW_VERSION" ]]; then + NEW_VERSION=$(fetch_latest_github_release) + if [[ -z "$NEW_VERSION" ]]; then + echo "Error: Could not fetch latest release from GitHub" >&2 + exit 1 + fi +fi +echo "Target version: ${NEW_VERSION}" >&2 + +if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then + echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2 + exit 0 +fi + +echo "" >&2 + +# Fetch current files +echo "Fetching current Makefile from FreeBSD ports..." >&2 +OLD_MAKEFILE=$(fetch_ports_file "Makefile") +if [[ -z "$OLD_MAKEFILE" ]]; then + echo "Error: Could not fetch Makefile" >&2 + exit 1 +fi + +echo "Fetching current distinfo from FreeBSD ports..." >&2 +OLD_DISTINFO=$(fetch_ports_file "distinfo") +if [[ -z "$OLD_DISTINFO" ]]; then + echo "Error: Could not fetch distinfo" >&2 + exit 1 +fi + +# Generate new files +echo "Generating new Makefile..." >&2 +NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE") + +echo "Computing checksums for new version..." >&2 +NEW_DISTINFO=$(compute_checksums "$NEW_VERSION") +if [[ -z "$NEW_DISTINFO" ]]; then + echo "Error: Could not compute checksums" >&2 + exit 1 +fi + +# Create temp files for diff +TMPDIR=$(mktemp -d) +# shellcheck disable=SC2064 +trap "rm -rf '$TMPDIR'" EXIT + +mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird" + +echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile" +echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo" +echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile" +echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo" + +# Generate diff +OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff" + +echo "" >&2 +echo "Generating diff..." >&2 + +# Generate diff and clean up temp paths to show standard a/b paths +(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true + +if [[ ! -s "$OUTPUT_FILE" ]]; then + echo "Error: Generated diff is empty" >&2 + exit 1 +fi + +echo "" >&2 +echo "=========================================" +echo "Diff saved to: ${OUTPUT_FILE}" +echo "=========================================" +echo "" +cat "$OUTPUT_FILE" +echo "" +echo "=========================================" +echo "" +echo "Next steps:" +echo "1. Review the diff above" +echo "2. Submit to https://bugs.freebsd.org/bugzilla/" +echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content" +echo "" +echo "For FreeBSD testing (optional but recommended):" +echo " cd /usr/ports/security/netbird" +echo " patch < ${OUTPUT_FILE}" +echo " make stage && make stage-qa && make package && make install" +echo " netbird status" +echo " make deinstall" diff --git a/release_files/freebsd-port-issue-body.sh b/release_files/freebsd-port-issue-body.sh new file mode 100755 index 000000000..b7ad0f5b1 --- /dev/null +++ b/release_files/freebsd-port-issue-body.sh @@ -0,0 +1,159 @@ +#!/bin/bash +# +# FreeBSD Port Issue Body Generator for NetBird +# +# This script generates the issue body content for submitting a FreeBSD port update +# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/ +# +# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version] +# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1 +# +# If no versions are provided, the script will: +# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree) +# - Fetch NEW version from latest NetBird GitHub release tag + +set -e + +GITHUB_REPO="netbirdio/netbird" +PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile" + +fetch_current_ports_version() { + echo "Fetching current version from FreeBSD ports..." >&2 + local makefile_content + makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null) + if [[ -z "$makefile_content" ]]; then + echo "Error: Could not fetch Makefile from FreeBSD ports" >&2 + return 1 + fi + echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t ' + return 0 +} + +fetch_all_tags() { + # Fetch tags from GitHub tags page (no rate limiting, no auth needed) + curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ + grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ + sed 's/.*\/v//' | \ + sort -u -V + return 0 +} + +fetch_latest_github_release() { + echo "Fetching latest release from GitHub..." >&2 + local latest + + # Fetch from GitHub tags page + latest=$(fetch_all_tags | tail -1) + + if [[ -z "$latest" ]]; then + # Fallback to GitHub API + latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \ + grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/') + fi + + if [[ -z "$latest" ]]; then + echo "Error: Could not fetch latest release from GitHub" >&2 + return 1 + fi + echo "$latest" + return 0 +} + +OLD_VERSION="${1:-}" +NEW_VERSION="${2:-}" + +if [[ -z "$OLD_VERSION" ]]; then + OLD_VERSION=$(fetch_current_ports_version) + if [[ -z "$OLD_VERSION" ]]; then + echo "Error: Could not determine old version. Please provide it manually." >&2 + echo "Usage: $0 " >&2 + exit 1 + fi + echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2 +fi + +if [[ -z "$NEW_VERSION" ]]; then + NEW_VERSION=$(fetch_latest_github_release) + if [[ -z "$NEW_VERSION" ]]; then + echo "Error: Could not determine new version. Please provide it manually." >&2 + echo "Usage: $0 " >&2 + exit 1 + fi + echo "Detected NEW version from GitHub: $NEW_VERSION" >&2 +fi + +if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then + echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2 +fi + +echo "" >&2 + +OUTPUT_DIR="${OUTPUT_DIR:-.}" + +fetch_releases_between_versions() { + echo "Fetching release history from GitHub..." >&2 + + # Fetch all tags and filter to those between OLD and NEW versions + fetch_all_tags | \ + while read -r ver; do + if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \ + [[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \ + [[ "$ver" != "$OLD_VERSION" ]]; then + echo "$ver" + fi + done + return 0 +} + +generate_changelog_section() { + local releases + releases=$(fetch_releases_between_versions) + + echo "Changelogs:" + if [[ -n "$releases" ]]; then + echo "$releases" | while read -r ver; do + echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}" + done + else + echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}" + fi + return 0 +} + +OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt" + +cat << EOF > "$OUTPUT_FILE" +BUGZILLA ISSUE DETAILS +====================== + +Severity: Affects Some People + +Summary: security/netbird: Update to ${NEW_VERSION} + +Description: +------------ +security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION} + +$(generate_changelog_section) + +Commit log: +https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION} +EOF + +echo "=========================================" +echo "Issue body saved to: ${OUTPUT_FILE}" +echo "=========================================" +echo "" +cat "$OUTPUT_FILE" +echo "" +echo "=========================================" +echo "" +echo "Next steps:" +echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login" +echo "2. Click 'Report an update or defect to a port'" +echo "3. Fill in:" +echo " - Severity: Affects Some People" +echo " - Summary: security/netbird: Update to ${NEW_VERSION}" +echo " - Description: Copy content from ${OUTPUT_FILE}" +echo "4. Attach diff file: netbird-${NEW_VERSION}.diff" +echo "5. Submit the bug report" diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: r.Description, NetID: r.NetID, Network: r.Network, diff --git a/management/server/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go similarity index 92% rename from management/server/auth/jwt/extractor.go rename to shared/auth/jwt/extractor.go index d270d0ff1..a41d5f07a 100644 --- a/management/server/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -8,7 +8,7 @@ import ( "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" ) const ( @@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string { return url } -func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) { +// ToUserAuth extracts user authentication information from a JWT token +func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) { claims := token.Claims.(jwt.MapClaims) - userAuth := nbcontext.UserAuth{} + userAuth := auth.UserAuth{} userID, ok := claims[c.userIDClaim].(string) if !ok { @@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro return userAuth, nil } +// ToGroups extracts group information from a JWT token func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string { claims := token.Claims.(jwt.MapClaims) userJWTGroups := make([]string, 0) diff --git a/management/server/auth/jwt/validator.go b/shared/auth/jwt/validator.go similarity index 100% rename from management/server/auth/jwt/validator.go rename to shared/auth/jwt/validator.go diff --git a/shared/auth/user.go b/shared/auth/user.go new file mode 100644 index 000000000..c1bae808e --- /dev/null +++ b/shared/auth/user.go @@ -0,0 +1,28 @@ +package auth + +import ( + "time" +) + +type UserAuth struct { + // The account id the user is accessing + AccountId string + // The account domain + Domain string + // The account domain category, TBC values + DomainCategory string + // Indicates whether this user was invited, TBC logic + Invited bool + // Indicates whether this is a child account + IsChild bool + + // The user id + UserId string + // Last login time for this user + LastLogin time.Time + // The Groups the user belongs to on this account + Groups []string + + // Indicates whether this user has authenticated with a Personal Access Token + IsPAT bool +} diff --git a/shared/context/keys.go b/shared/context/keys.go index 5345ee214..c5b5da044 100644 --- a/shared/context/keys.go +++ b/shared/context/keys.go @@ -5,4 +5,4 @@ const ( AccountIDKey = "accountID" UserIDKey = "userID" PeerIDKey = "peerID" -) \ No newline at end of file +) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d4a9f1823..9fbe70948 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -19,6 +19,12 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/server/config" @@ -27,8 +33,6 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -68,7 +72,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctrl := gomock.NewController(t) @@ -111,15 +114,22 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } groupsManager := groups.NewManagerMock() - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go index 699617574..451578358 100644 --- a/shared/management/client/common/types.go +++ b/shared/management/client/common/types.go @@ -1,19 +1,20 @@ package common -// LoginFlag introduces additional login flags to the PKCE authorization request +// LoginFlag introduces additional login flags to the PKCE authorization request. +// +// # Config Values +// +// | Value | Flag | OAuth Parameters | +// |-------|----------------------|-----------------------------------------| +// | 0 | LoginFlagPromptLogin | prompt=login | +// | 1 | LoginFlagMaxAge0 | max_age=0 | type LoginFlag uint8 const ( - // LoginFlagPrompt adds prompt=login to the authorization request - LoginFlagPrompt LoginFlag = iota + // LoginFlagPromptLogin adds prompt=login to the authorization request + LoginFlagPromptLogin LoginFlag = iota // LoginFlagMaxAge0 adds max_age=0 to the authorization request LoginFlagMaxAge0 + // LoginFlagNone disables all login flags + LoginFlagNone ) - -func (l LoginFlag) IsPromptLogin() bool { - return l == LoginFlagPrompt -} - -func (l LoginFlag) IsMaxAge0Login() bool { - return l == LoginFlagMaxAge0 -} diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..89860ac9b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } @@ -112,6 +111,8 @@ func (c *GrpcClient) ready() bool { // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Blocking request. The result will be sent via msgHandler callback function func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error { + backOff := defaultBackoff(ctx) + operation := func() error { log.Debugf("management connection state %v", c.conn.GetState()) connState := c.conn.GetState() @@ -129,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return err } - return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) + return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff) } - err := backoff.Retry(operation, defaultBackoff(ctx)) + err := backoff.Retry(operation, backOff) if err != nil { log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err) } @@ -141,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler } func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, - msgHandler func(msg *proto.SyncResponse) error) error { + msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error { ctx, cancelStream := context.WithCancel(ctx) defer cancelStream() @@ -159,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, // blocking until error err = c.receiveEvents(stream, serverPubKey, msgHandler) + // we need this reset because after a successful connection and a consequent error, backoff lib doesn't + // reset times and next try will start with a long delay + backOff.Reset() if err != nil { c.notifyDisconnected(err) s, _ := gstatus.FromError(err) diff --git a/shared/management/client/rest/groups.go b/shared/management/client/rest/groups.go index af068e077..7cd9535dd 100644 --- a/shared/management/client/rest/groups.go +++ b/shared/management/client/rest/groups.go @@ -4,10 +4,14 @@ import ( "bytes" "context" "encoding/json" + "errors" "github.com/netbirdio/netbird/shared/management/http/api" ) +// ErrGroupNotFound is returned when a group is not found +var ErrGroupNotFound = errors.New("group not found") + // GroupsAPI APIs for Groups, do not use directly type GroupsAPI struct { c *Client @@ -27,6 +31,27 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { return ret, err } +// GetByName get group by name +// See more: https://docs.netbird.io/api/resources/groups#list-all-groups +func (a *GroupsAPI) GetByName(ctx context.Context, groupName string) (*api.Group, error) { + params := map[string]string{"name": groupName} + resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, params) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.Group](resp) + if err != nil { + return nil, err + } + if len(ret) == 0 { + return nil, ErrGroupNotFound + } + return &ret[0], nil +} + // Get get group info // See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) { diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..c9edcdda6 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -145,6 +145,10 @@ components: description: Enables or disables experimental lazy connection type: boolean example: true + auto_update_version: + description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") + type: string + example: "0.51.2" required: - peer_login_expiration_enabled - peer_login_expiration @@ -463,6 +467,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: @@ -481,6 +488,8 @@ components: description: Indicates whether the peer is ephemeral or not type: boolean example: false + local_flags: + $ref: '#/components/schemas/PeerLocalFlags' required: - city_name - connected @@ -507,6 +516,49 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerLocalFlags: + type: object + properties: + rosenpass_enabled: + description: Indicates whether Rosenpass is enabled on this peer + type: boolean + example: true + rosenpass_permissive: + description: Indicates whether Rosenpass is in permissive mode or not + type: boolean + example: false + server_ssh_allowed: + description: Indicates whether SSH access this peer is allowed or not + type: boolean + example: true + disable_client_routes: + description: Indicates whether client routes are disabled on this peer or not + type: boolean + example: false + disable_server_routes: + description: Indicates whether server routes are disabled on this peer or not + type: boolean + example: false + disable_dns: + description: Indicates whether DNS management is disabled on this peer or not + type: boolean + example: false + disable_firewall: + description: Indicates whether firewall management is disabled on this peer or not + type: boolean + example: false + block_lan_access: + description: Indicates whether LAN access is blocked on this peer when used as a routing peer + type: boolean + example: false + block_inbound: + description: Indicates whether inbound traffic is blocked on this peer + type: boolean + example: false + lazy_connection_enabled: + description: Indicates whether lazy connection is enabled on this peer + type: boolean + example: false PeerTemporaryAccessRequest: type: object properties: @@ -929,7 +981,7 @@ components: protocol: description: Policy rule type of the traffic type: string - enum: ["all", "tcp", "udp", "icmp"] + enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"] example: "tcp" ports: description: Policy rule affected ports @@ -942,6 +994,14 @@ components: type: array items: $ref: '#/components/schemas/RulePortRange' + authorized_groups: + description: Map of user group ids to a list of local users + type: object + additionalProperties: + type: array + items: + type: string + example: "group1" required: - name - enabled @@ -3359,6 +3419,14 @@ paths: security: - BearerAuth: [ ] - TokenAuth: [ ] + parameters: + - in: query + name: name + required: false + schema: + type: string + description: Filter groups by name (exact match) + example: "devs" responses: '200': description: A JSON Array of Groups @@ -3372,6 +3440,8 @@ paths: "$ref": "#/components/responses/bad_request" '401': "$ref": "#/components/responses/requires_authentication" + '404': + "$ref": "#/components/responses/not_found" '403': "$ref": "#/components/responses/forbidden" '500': diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..f242f5a18 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -130,10 +130,11 @@ const ( // Defines values for PolicyRuleProtocol. const ( - PolicyRuleProtocolAll PolicyRuleProtocol = "all" - PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" - PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" - PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" + PolicyRuleProtocolAll PolicyRuleProtocol = "all" + PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" + PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh" + PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" + PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" ) // Defines values for PolicyRuleMinimumAction. @@ -144,10 +145,11 @@ const ( // Defines values for PolicyRuleMinimumProtocol. const ( - PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" - PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" - PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" - PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" + PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" + PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" + PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh" + PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" + PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" ) // Defines values for PolicyRuleUpdateAction. @@ -158,10 +160,11 @@ const ( // Defines values for PolicyRuleUpdateProtocol. const ( - PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" - PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" - PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" - PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" + PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" + PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" + PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh" + PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" + PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) // Defines values for ResourceType. @@ -291,6 +294,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") + AutoUpdateVersion *string `json:"auto_update_version,omitempty"` + // DnsDomain Allows to define a custom dns domain for the account DnsDomain *string `json:"dns_domain,omitempty"` Extra *AccountExtraSettings `json:"extra,omitempty"` @@ -1037,6 +1043,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1071,7 +1080,8 @@ type Peer struct { LastLogin time.Time `json:"last_login"` // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` + LastSeen time.Time `json:"last_seen"` + LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"` // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not LoginExpirationEnabled bool `json:"login_expiration_enabled"` @@ -1124,6 +1134,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1158,7 +1171,8 @@ type PeerBatch struct { LastLogin time.Time `json:"last_login"` // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` + LastSeen time.Time `json:"last_seen"` + LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"` // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not LoginExpirationEnabled bool `json:"login_expiration_enabled"` @@ -1188,6 +1202,39 @@ type PeerBatch struct { Version string `json:"version"` } +// PeerLocalFlags defines model for PeerLocalFlags. +type PeerLocalFlags struct { + // BlockInbound Indicates whether inbound traffic is blocked on this peer + BlockInbound *bool `json:"block_inbound,omitempty"` + + // BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer + BlockLanAccess *bool `json:"block_lan_access,omitempty"` + + // DisableClientRoutes Indicates whether client routes are disabled on this peer or not + DisableClientRoutes *bool `json:"disable_client_routes,omitempty"` + + // DisableDns Indicates whether DNS management is disabled on this peer or not + DisableDns *bool `json:"disable_dns,omitempty"` + + // DisableFirewall Indicates whether firewall management is disabled on this peer or not + DisableFirewall *bool `json:"disable_firewall,omitempty"` + + // DisableServerRoutes Indicates whether server routes are disabled on this peer or not + DisableServerRoutes *bool `json:"disable_server_routes,omitempty"` + + // LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer + LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` + + // RosenpassEnabled Indicates whether Rosenpass is enabled on this peer + RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"` + + // RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not + RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"` + + // ServerSshAllowed Indicates whether SSH access this peer is allowed or not + ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"` +} + // PeerMinimum defines model for PeerMinimum. type PeerMinimum struct { // Id Peer ID @@ -1340,6 +1387,9 @@ type PolicyRule struct { // Action Policy rule accept or drops packets Action PolicyRuleAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1384,6 +1434,9 @@ type PolicyRuleMinimum struct { // Action Policy rule accept or drops packets Action PolicyRuleMinimumAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1417,6 +1470,9 @@ type PolicyRuleUpdate struct { // Action Policy rule accept or drops packets Action PolicyRuleUpdateAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1902,6 +1958,12 @@ type GetApiEventsNetworkTrafficParamsConnectionType string // GetApiEventsNetworkTrafficParamsDirection defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParamsDirection string +// GetApiGroupsParams defines parameters for GetApiGroups. +type GetApiGroupsParams struct { + // Name Filter groups by name (exact match) + Name *string `form:"name,omitempty" json:"name,omitempty"` +} + // GetApiPeersParams defines parameters for GetApiPeers. type GetApiPeersParams struct { // Name Filter peers by name diff --git a/shared/management/http/util/util.go b/shared/management/http/util/util.go index 3ae321023..0a29469da 100644 --- a/shared/management/http/util/util.go +++ b/shared/management/http/util/util.go @@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized case status.BadRequest: httpStatus = http.StatusBadRequest + case status.TooManyRequests: + httpStatus = http.StatusTooManyRequests default: } msg = strings.ToLower(err.Error()) diff --git a/shared/management/operations/operation.go b/shared/management/operations/operation.go index b9b500362..b1ba12815 100644 --- a/shared/management/operations/operation.go +++ b/shared/management/operations/operation.go @@ -1,4 +1,4 @@ package operations // Operation represents a permission operation type -type Operation string \ No newline at end of file +type Operation string diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 0de00ec0c..2047c51ea 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,19 +1,18 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.0 +// protoc v6.33.1 // source: management.proto package proto import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" ) const ( @@ -268,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23, 0} + return file_management_proto_rawDescGZIP(), []int{27, 0} } type EncryptedMessage struct { @@ -799,16 +798,21 @@ type Flags struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` - DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"` - DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"` - DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"` - DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"` - BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"` - BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` - LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"` + DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"` + DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"` + DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"` + BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"` + BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` + LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + EnableSSHRoot bool `protobuf:"varint,11,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP bool `protobuf:"varint,12,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding bool `protobuf:"varint,13,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding bool `protobuf:"varint,14,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth bool `protobuf:"varint,15,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` } func (x *Flags) Reset() { @@ -913,6 +917,41 @@ func (x *Flags) GetLazyConnectionEnabled() bool { return false } +func (x *Flags) GetEnableSSHRoot() bool { + if x != nil { + return x.EnableSSHRoot + } + return false +} + +func (x *Flags) GetEnableSSHSFTP() bool { + if x != nil { + return x.EnableSSHSFTP + } + return false +} + +func (x *Flags) GetEnableSSHLocalPortForwarding() bool { + if x != nil { + return x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *Flags) GetEnableSSHRemotePortForwarding() bool { + if x != nil { + return x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *Flags) GetDisableSSHAuth() bool { + if x != nil { + return x.DisableSSHAuth + } + return false +} + // PeerSystemMeta is machine meta data like OS and version. type PeerSystemMeta struct { state protoimpl.MessageState @@ -1568,6 +1607,78 @@ func (x *FlowConfig) GetDnsCollection() bool { return false } +// JWTConfig represents JWT authentication configuration +type JWTConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Issuer string `protobuf:"bytes,1,opt,name=issuer,proto3" json:"issuer,omitempty"` + Audience string `protobuf:"bytes,2,opt,name=audience,proto3" json:"audience,omitempty"` + KeysLocation string `protobuf:"bytes,3,opt,name=keysLocation,proto3" json:"keysLocation,omitempty"` + MaxTokenAge int64 `protobuf:"varint,4,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"` +} + +func (x *JWTConfig) Reset() { + *x = JWTConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *JWTConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*JWTConfig) ProtoMessage() {} + +func (x *JWTConfig) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[17] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use JWTConfig.ProtoReflect.Descriptor instead. +func (*JWTConfig) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{17} +} + +func (x *JWTConfig) GetIssuer() string { + if x != nil { + return x.Issuer + } + return "" +} + +func (x *JWTConfig) GetAudience() string { + if x != nil { + return x.Audience + } + return "" +} + +func (x *JWTConfig) GetKeysLocation() string { + if x != nil { + return x.KeysLocation + } + return "" +} + +func (x *JWTConfig) GetMaxTokenAge() int64 { + if x != nil { + return x.MaxTokenAge + } + return 0 +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers type ProtectedHostConfig struct { @@ -1583,7 +1694,7 @@ type ProtectedHostConfig struct { func (x *ProtectedHostConfig) Reset() { *x = ProtectedHostConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1596,7 +1707,7 @@ func (x *ProtectedHostConfig) String() string { func (*ProtectedHostConfig) ProtoMessage() {} func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1609,7 +1720,7 @@ func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProtectedHostConfig.ProtoReflect.Descriptor instead. func (*ProtectedHostConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{17} + return file_management_proto_rawDescGZIP(), []int{18} } func (x *ProtectedHostConfig) GetHostConfig() *HostConfig { @@ -1651,12 +1762,14 @@ type PeerConfig struct { RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"` Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` + // Auto-update config + AutoUpdate *AutoUpdateSettings `protobuf:"bytes,8,opt,name=autoUpdate,proto3" json:"autoUpdate,omitempty"` } func (x *PeerConfig) Reset() { *x = PeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1669,7 +1782,7 @@ func (x *PeerConfig) String() string { func (*PeerConfig) ProtoMessage() {} func (x *PeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1682,7 +1795,7 @@ func (x *PeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead. func (*PeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{18} + return file_management_proto_rawDescGZIP(), []int{19} } func (x *PeerConfig) GetAddress() string { @@ -1734,6 +1847,70 @@ func (x *PeerConfig) GetMtu() int32 { return 0 } +func (x *PeerConfig) GetAutoUpdate() *AutoUpdateSettings { + if x != nil { + return x.AutoUpdate + } + return nil +} + +type AutoUpdateSettings struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` + // alwaysUpdate = true → Updates happen automatically in the background + // alwaysUpdate = false → Updates only happen when triggered by a peer connection + AlwaysUpdate bool `protobuf:"varint,2,opt,name=alwaysUpdate,proto3" json:"alwaysUpdate,omitempty"` +} + +func (x *AutoUpdateSettings) Reset() { + *x = AutoUpdateSettings{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AutoUpdateSettings) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AutoUpdateSettings) ProtoMessage() {} + +func (x *AutoUpdateSettings) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[20] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AutoUpdateSettings.ProtoReflect.Descriptor instead. +func (*AutoUpdateSettings) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{20} +} + +func (x *AutoUpdateSettings) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + +func (x *AutoUpdateSettings) GetAlwaysUpdate() bool { + if x != nil { + return x.AlwaysUpdate + } + return false +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -1765,12 +1942,14 @@ type NetworkMap struct { // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"` ForwardingRules []*ForwardingRule `protobuf:"bytes,12,rep,name=forwardingRules,proto3" json:"forwardingRules,omitempty"` + // SSHAuth represents SSH authorization configuration + SshAuth *SSHAuth `protobuf:"bytes,13,opt,name=sshAuth,proto3" json:"sshAuth,omitempty"` } func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1783,7 +1962,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1796,7 +1975,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{19} + return file_management_proto_rawDescGZIP(), []int{21} } func (x *NetworkMap) GetSerial() uint64 { @@ -1883,6 +2062,126 @@ func (x *NetworkMap) GetForwardingRules() []*ForwardingRule { return nil } +func (x *NetworkMap) GetSshAuth() *SSHAuth { + if x != nil { + return x.SshAuth + } + return nil +} + +type SSHAuth struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // UserIDClaim is the JWT claim to be used to get the users ID + UserIDClaim string `protobuf:"bytes,1,opt,name=UserIDClaim,proto3" json:"UserIDClaim,omitempty"` + // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH + AuthorizedUsers [][]byte `protobuf:"bytes,2,rep,name=AuthorizedUsers,proto3" json:"AuthorizedUsers,omitempty"` + // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list + MachineUsers map[string]*MachineUserIndexes `protobuf:"bytes,3,rep,name=machine_users,json=machineUsers,proto3" json:"machine_users,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (x *SSHAuth) Reset() { + *x = SSHAuth{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SSHAuth) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHAuth) ProtoMessage() {} + +func (x *SSHAuth) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[22] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHAuth.ProtoReflect.Descriptor instead. +func (*SSHAuth) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{22} +} + +func (x *SSHAuth) GetUserIDClaim() string { + if x != nil { + return x.UserIDClaim + } + return "" +} + +func (x *SSHAuth) GetAuthorizedUsers() [][]byte { + if x != nil { + return x.AuthorizedUsers + } + return nil +} + +func (x *SSHAuth) GetMachineUsers() map[string]*MachineUserIndexes { + if x != nil { + return x.MachineUsers + } + return nil +} + +type MachineUserIndexes struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Indexes []uint32 `protobuf:"varint,1,rep,packed,name=indexes,proto3" json:"indexes,omitempty"` +} + +func (x *MachineUserIndexes) Reset() { + *x = MachineUserIndexes{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MachineUserIndexes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MachineUserIndexes) ProtoMessage() {} + +func (x *MachineUserIndexes) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MachineUserIndexes.ProtoReflect.Descriptor instead. +func (*MachineUserIndexes) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{23} +} + +func (x *MachineUserIndexes) GetIndexes() []uint32 { + if x != nil { + return x.Indexes + } + return nil +} + // RemotePeerConfig represents a configuration of a remote peer. // The properties are used to configure WireGuard Peers sections type RemotePeerConfig struct { @@ -1904,7 +2203,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1917,7 +2216,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1930,7 +2229,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{20} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -1978,13 +2277,14 @@ type SSHConfig struct { SshEnabled bool `protobuf:"varint,1,opt,name=sshEnabled,proto3" json:"sshEnabled,omitempty"` // sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // This property should be ignore if SSHConfig comes from PeerConfig. - SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"` + SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"` + JwtConfig *JWTConfig `protobuf:"bytes,3,opt,name=jwtConfig,proto3" json:"jwtConfig,omitempty"` } func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1997,7 +2297,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2010,7 +2310,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21} + return file_management_proto_rawDescGZIP(), []int{25} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2027,6 +2327,13 @@ func (x *SSHConfig) GetSshPubKey() []byte { return nil } +func (x *SSHConfig) GetJwtConfig() *JWTConfig { + if x != nil { + return x.JwtConfig + } + return nil +} + // DeviceAuthorizationFlowRequest empty struct for future expansion type DeviceAuthorizationFlowRequest struct { state protoimpl.MessageState @@ -2037,7 +2344,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2050,7 +2357,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2063,7 +2370,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{26} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -2082,7 +2389,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2095,7 +2402,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2108,7 +2415,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -2135,7 +2442,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2148,7 +2455,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2161,7 +2468,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{28} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -2178,7 +2485,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2191,7 +2498,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2204,7 +2511,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -2250,7 +2557,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2263,7 +2570,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2276,7 +2583,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *ProviderConfig) GetClientID() string { @@ -2384,7 +2691,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2397,7 +2704,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2410,7 +2717,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *Route) GetID() string { @@ -2492,13 +2799,14 @@ type DNSConfig struct { ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"` NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"` CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"` - ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` + // Deprecated: Do not use. + ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` } func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2511,7 +2819,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2524,7 +2832,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2548,6 +2856,7 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone { return nil } +// Deprecated: Do not use. func (x *DNSConfig) GetForwarderPort() int64 { if x != nil { return x.ForwarderPort @@ -2561,14 +2870,16 @@ type CustomZone struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` - Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` + Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` + Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` + SearchDomainDisabled bool `protobuf:"varint,3,opt,name=SearchDomainDisabled,proto3" json:"SearchDomainDisabled,omitempty"` + SkipPTRProcess bool `protobuf:"varint,4,opt,name=SkipPTRProcess,proto3" json:"SkipPTRProcess,omitempty"` } func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2581,7 +2892,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2594,7 +2905,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{33} } func (x *CustomZone) GetDomain() string { @@ -2611,6 +2922,20 @@ func (x *CustomZone) GetRecords() []*SimpleRecord { return nil } +func (x *CustomZone) GetSearchDomainDisabled() bool { + if x != nil { + return x.SearchDomainDisabled + } + return false +} + +func (x *CustomZone) GetSkipPTRProcess() bool { + if x != nil { + return x.SkipPTRProcess + } + return false +} + // SimpleRecord represents a dns.SimpleRecord type SimpleRecord struct { state protoimpl.MessageState @@ -2627,7 +2952,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2640,7 +2965,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2653,7 +2978,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *SimpleRecord) GetName() string { @@ -2706,7 +3031,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2719,7 +3044,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2732,7 +3057,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{35} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2777,7 +3102,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2790,7 +3115,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2803,7 +3128,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *NameServer) GetIP() string { @@ -2846,7 +3171,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2859,7 +3184,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2872,7 +3197,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{37} } func (x *FirewallRule) GetPeerIP() string { @@ -2936,7 +3261,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2949,7 +3274,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2962,7 +3287,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{38} } func (x *NetworkAddress) GetNetIP() string { @@ -2990,7 +3315,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3003,7 +3328,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3016,7 +3341,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *Checks) GetFiles() []string { @@ -3041,7 +3366,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3054,7 +3379,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3067,7 +3392,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{40} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -3138,7 +3463,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3151,7 +3476,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[41] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3164,7 +3489,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{41} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -3255,7 +3580,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3268,7 +3593,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[42] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3281,7 +3606,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{42} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -3324,7 +3649,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3337,7 +3662,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[44] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3350,7 +3675,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36, 0} + return file_management_proto_rawDescGZIP(), []int{40, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -3438,7 +3763,7 @@ var file_management_proto_rawDesc = []byte{ 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, - 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xc1, 0x03, 0x0a, 0x05, + 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xbf, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, @@ -3466,435 +3791,500 @@ var file_management_proto_rawDesc = []byte{ 0x63, 0x6b, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, - 0xf2, 0x04, 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, - 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, - 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, - 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, - 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, - 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, - 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, - 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, - 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, - 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, - 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, - 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, - 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, - 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, - 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, - 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, - 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, - 0x6c, 0x61, 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, - 0x6c, 0x61, 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, - 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, - 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, - 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, - 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, - 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, - 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, - 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, - 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, - 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, - 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, - 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, - 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, - 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, - 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, - 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, - 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, - 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, - 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, - 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, - 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, - 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, - 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, - 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, - 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x7d, 0x0a, 0x13, 0x50, - 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, - 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, - 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, - 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, - 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, - 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, - 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, - 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, - 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, - 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, - 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, - 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, - 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, - 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, - 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, - 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, - 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, - 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, - 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, - 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, - 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, - 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, - 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x6f, 0x6f, 0x74, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, + 0x48, 0x52, 0x6f, 0x6f, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, + 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x12, 0x42, 0x0a, 0x1c, 0x65, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x6f, 0x72, + 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x18, 0x0d, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x1c, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, + 0x6c, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, + 0x44, 0x0a, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, + 0x48, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x26, 0x0a, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x22, 0xf2, 0x04, + 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, + 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, + 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, + 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, + 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, + 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, + 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, + 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, + 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, + 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, + 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, + 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, + 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, + 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, + 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, + 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, + 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, + 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, + 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, + 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, + 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, + 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, + 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, + 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, + 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, + 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, + 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, + 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, + 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, + 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, + 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, + 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, + 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, + 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, + 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, + 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, + 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, + 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, + 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, + 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, + 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, + 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x22, 0xd3, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, + 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, + 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, + 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, + 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, + 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, + 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, + 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, + 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, + 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, + 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, + 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, + 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, + 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, 0x82, 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, + 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, + 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, + 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, + 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, + 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, + 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, + 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, + 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, + 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, + 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, - 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, + 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, + 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, + 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, + 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, + 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, - 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, - 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x24, - 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, - 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, - 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, - 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, - 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, - 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, - 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, + 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, + 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, + 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, + 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, + 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, + 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, + 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, + 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, + 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, + 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, + 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, + 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, + 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, - 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, - 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, - 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, - 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, - 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, - 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, - 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, - 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, - 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, - 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, - 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, - 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, - 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, - 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, + 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, + 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, + 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, + 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, + 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, + 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, + 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, + 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, + 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, + 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, + 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, + 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, + 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, - 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3910,7 +4300,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 40) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 45) var file_management_proto_goTypes = []interface{}{ (RuleProtocol)(0), // 0: management.RuleProtocol (RuleDirection)(0), // 1: management.RuleDirection @@ -3934,107 +4324,117 @@ var file_management_proto_goTypes = []interface{}{ (*HostConfig)(nil), // 19: management.HostConfig (*RelayConfig)(nil), // 20: management.RelayConfig (*FlowConfig)(nil), // 21: management.FlowConfig - (*ProtectedHostConfig)(nil), // 22: management.ProtectedHostConfig - (*PeerConfig)(nil), // 23: management.PeerConfig - (*NetworkMap)(nil), // 24: management.NetworkMap - (*RemotePeerConfig)(nil), // 25: management.RemotePeerConfig - (*SSHConfig)(nil), // 26: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 27: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 28: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 29: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 30: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 31: management.ProviderConfig - (*Route)(nil), // 32: management.Route - (*DNSConfig)(nil), // 33: management.DNSConfig - (*CustomZone)(nil), // 34: management.CustomZone - (*SimpleRecord)(nil), // 35: management.SimpleRecord - (*NameServerGroup)(nil), // 36: management.NameServerGroup - (*NameServer)(nil), // 37: management.NameServer - (*FirewallRule)(nil), // 38: management.FirewallRule - (*NetworkAddress)(nil), // 39: management.NetworkAddress - (*Checks)(nil), // 40: management.Checks - (*PortInfo)(nil), // 41: management.PortInfo - (*RouteFirewallRule)(nil), // 42: management.RouteFirewallRule - (*ForwardingRule)(nil), // 43: management.ForwardingRule - (*PortInfo_Range)(nil), // 44: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 45: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 46: google.protobuf.Duration + (*JWTConfig)(nil), // 22: management.JWTConfig + (*ProtectedHostConfig)(nil), // 23: management.ProtectedHostConfig + (*PeerConfig)(nil), // 24: management.PeerConfig + (*AutoUpdateSettings)(nil), // 25: management.AutoUpdateSettings + (*NetworkMap)(nil), // 26: management.NetworkMap + (*SSHAuth)(nil), // 27: management.SSHAuth + (*MachineUserIndexes)(nil), // 28: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 29: management.RemotePeerConfig + (*SSHConfig)(nil), // 30: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 31: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 32: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 33: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 34: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 35: management.ProviderConfig + (*Route)(nil), // 36: management.Route + (*DNSConfig)(nil), // 37: management.DNSConfig + (*CustomZone)(nil), // 38: management.CustomZone + (*SimpleRecord)(nil), // 39: management.SimpleRecord + (*NameServerGroup)(nil), // 40: management.NameServerGroup + (*NameServer)(nil), // 41: management.NameServer + (*FirewallRule)(nil), // 42: management.FirewallRule + (*NetworkAddress)(nil), // 43: management.NetworkAddress + (*Checks)(nil), // 44: management.Checks + (*PortInfo)(nil), // 45: management.PortInfo + (*RouteFirewallRule)(nil), // 46: management.RouteFirewallRule + (*ForwardingRule)(nil), // 47: management.ForwardingRule + nil, // 48: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 49: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 50: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 51: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 23, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 25, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 24, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 40, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 29, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 26, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 44, // 5: management.SyncResponse.Checks:type_name -> management.Checks 14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 39, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 43, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags 18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 23, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 40, // 15: management.LoginResponse.Checks:type_name -> management.Checks - 45, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 44, // 15: management.LoginResponse.Checks:type_name -> management.Checks + 50, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 22, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 46, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 51, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 26, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 23, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 25, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 32, // 28: management.NetworkMap.Routes:type_name -> management.Route - 33, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 25, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 38, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 42, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 43, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 26, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 4, // 35: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 31, // 36: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 31, // 37: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 36, // 38: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 34, // 39: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 35, // 40: management.CustomZone.Records:type_name -> management.SimpleRecord - 37, // 41: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 42: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 43: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 44: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 41, // 45: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 44, // 46: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 47: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 48: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 41, // 49: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 50: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 41, // 51: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 41, // 52: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 53: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 54: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 55: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 56: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 57: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 58: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 59: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 63: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 64: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 65: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 66: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 67: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 68: management.ManagementService.Logout:output_type -> management.Empty - 61, // [61:69] is the sub-list for method output_type - 53, // [53:61] is the sub-list for method input_type - 53, // [53:53] is the sub-list for extension type_name - 53, // [53:53] is the sub-list for extension extendee - 0, // [0:53] is the sub-list for field type_name + 30, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 25, // 26: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 29, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 36, // 29: management.NetworkMap.Routes:type_name -> management.Route + 37, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 29, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 42, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 46, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 47, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 35: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 48, // 36: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 30, // 37: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 38: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 39: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 35, // 40: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 35, // 41: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 40, // 42: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 38, // 43: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 39, // 44: management.CustomZone.Records:type_name -> management.SimpleRecord + 41, // 45: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 46: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 47: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 48: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 45, // 49: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 49, // 50: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 51: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 52: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 45, // 53: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 54: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 45, // 55: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 45, // 56: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 28, // 57: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 5, // 58: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 59: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 60: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 61: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 62: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 64: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 65: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 66: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 67: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 68: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 69: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 70: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 71: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 72: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 73: management.ManagementService.Logout:output_type -> management.Empty + 66, // [66:74] is the sub-list for method output_type + 58, // [58:66] is the sub-list for method input_type + 58, // [58:58] is the sub-list for extension type_name + 58, // [58:58] is the sub-list for extension extendee + 0, // [0:58] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -4248,7 +4648,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProtectedHostConfig); i { + switch v := v.(*JWTConfig); i { case 0: return &v.state case 1: @@ -4260,7 +4660,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerConfig); i { + switch v := v.(*ProtectedHostConfig); i { case 0: return &v.state case 1: @@ -4272,7 +4672,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*PeerConfig); i { case 0: return &v.state case 1: @@ -4284,7 +4684,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*AutoUpdateSettings); i { case 0: return &v.state case 1: @@ -4296,7 +4696,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -4308,7 +4708,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*SSHAuth); i { case 0: return &v.state case 1: @@ -4320,7 +4720,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*MachineUserIndexes); i { case 0: return &v.state case 1: @@ -4332,7 +4732,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -4344,7 +4744,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -4356,7 +4756,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4368,7 +4768,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4380,7 +4780,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4392,7 +4792,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4404,7 +4804,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -4416,7 +4816,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -4428,7 +4828,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -4440,7 +4840,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -4452,7 +4852,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -4464,7 +4864,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -4476,7 +4876,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -4488,7 +4888,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -4500,7 +4900,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -4512,6 +4912,54 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Checks); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardingRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[44].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -4524,7 +4972,7 @@ func file_management_proto_init() { } } } - file_management_proto_msgTypes[36].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[40].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -4534,7 +4982,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 40, + NumMessages: 45, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ad82d37d9..f2e591e88 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -146,6 +146,12 @@ message Flags { bool blockInbound = 9; bool lazyConnectionEnabled = 10; + + bool enableSSHRoot = 11; + bool enableSSHSFTP = 12; + bool enableSSHLocalPortForwarding = 13; + bool enableSSHRemotePortForwarding = 14; + bool disableSSHAuth = 15; } // PeerSystemMeta is machine meta data like OS and version. @@ -240,6 +246,14 @@ message FlowConfig { bool dnsCollection = 8; } +// JWTConfig represents JWT authentication configuration +message JWTConfig { + string issuer = 1; + string audience = 2; + string keysLocation = 3; + int64 maxTokenAge = 4; +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers message ProtectedHostConfig { @@ -266,6 +280,18 @@ message PeerConfig { bool LazyConnectionEnabled = 6; int32 mtu = 7; + + // Auto-update config + AutoUpdateSettings autoUpdate = 8; +} + +message AutoUpdateSettings { + string version = 1; + /* + alwaysUpdate = true → Updates happen automatically in the background + alwaysUpdate = false → Updates only happen when triggered by a peer connection + */ + bool alwaysUpdate = 2; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -306,6 +332,24 @@ message NetworkMap { bool routesFirewallRulesIsEmpty = 11; repeated ForwardingRule forwardingRules = 12; + + // SSHAuth represents SSH authorization configuration + SSHAuth sshAuth = 13; +} + +message SSHAuth { + // UserIDClaim is the JWT claim to be used to get the users ID + string UserIDClaim = 1; + + // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH + repeated bytes AuthorizedUsers = 2; + + // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list + map machine_users = 3; +} + +message MachineUserIndexes { + repeated uint32 indexes = 1; } // RemotePeerConfig represents a configuration of a remote peer. @@ -335,6 +379,8 @@ message SSHConfig { // sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // This property should be ignore if SSHConfig comes from PeerConfig. bytes sshPubKey = 2; + + JWTConfig jwtConfig = 3; } // DeviceAuthorizationFlowRequest empty struct for future expansion @@ -410,13 +456,15 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone message CustomZone { string Domain = 1; repeated SimpleRecord Records = 2; + bool SearchDomainDisabled = 3; + bool SkipPTRProcess = 4; } // SimpleRecord represents a dns.SimpleRecord diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1e914babb..09676847e 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -37,6 +37,9 @@ const ( // Unauthenticated indicates that user is not authenticated due to absence of valid credentials Unauthenticated Type = 10 + + // TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting) + TooManyRequests Type = 11 ) // Type is a type of the Error diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 967e18d79..c057ef089 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -11,8 +11,8 @@ import ( "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" - quictls "github.com/netbirdio/netbird/shared/relay/tls" nbnet "github.com/netbirdio/netbird/client/net" + quictls "github.com/netbirdio/netbird/shared/relay/tls" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 66fff3447..37b189e05 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -14,9 +14,9 @@ import ( "github.com/coder/websocket" log "github.com/sirupsen/logrus" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/constants.go b/shared/relay/constants.go index 3c7c3cd29..0f2a27610 100644 --- a/shared/relay/constants.go +++ b/shared/relay/constants.go @@ -3,4 +3,4 @@ package relay const ( // WebSocketURLPath is the path for the websocket relay connection WebSocketURLPath = "/relay" -) \ No newline at end of file +) diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/sshauth/userhash.go b/shared/sshauth/userhash.go new file mode 100644 index 000000000..276fc9ba2 --- /dev/null +++ b/shared/sshauth/userhash.go @@ -0,0 +1,28 @@ +package sshauth + +import ( + "encoding/hex" + + "golang.org/x/crypto/blake2b" +) + +// UserIDHash represents a hashed user ID (BLAKE2b-128) +type UserIDHash [16]byte + +// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value +// This function must produce the same hash on both client and management server +func HashUserID(userID string) (UserIDHash, error) { + hash, err := blake2b.New(16, nil) + if err != nil { + return UserIDHash{}, err + } + hash.Write([]byte(userID)) + var result UserIDHash + copy(result[:], hash.Sum(nil)) + return result, nil +} + +// String returns the hexadecimal string representation of the hash +func (h UserIDHash) String() string { + return hex.EncodeToString(h[:]) +} diff --git a/shared/sshauth/userhash_test.go b/shared/sshauth/userhash_test.go new file mode 100644 index 000000000..5a3cb6986 --- /dev/null +++ b/shared/sshauth/userhash_test.go @@ -0,0 +1,210 @@ +package sshauth + +import ( + "testing" +) + +func TestHashUserID(t *testing.T) { + tests := []struct { + name string + userID string + }{ + { + name: "simple user ID", + userID: "user@example.com", + }, + { + name: "UUID format", + userID: "550e8400-e29b-41d4-a716-446655440000", + }, + { + name: "numeric ID", + userID: "12345", + }, + { + name: "empty string", + userID: "", + }, + { + name: "special characters", + userID: "user+test@domain.com", + }, + { + name: "unicode characters", + userID: "用户@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashUserID(tt.userID) + if err != nil { + t.Errorf("HashUserID() error = %v, want nil", err) + return + } + + // Verify hash is non-zero for non-empty inputs + if tt.userID != "" && hash == [16]byte{} { + t.Errorf("HashUserID() returned zero hash for non-empty input") + } + }) + } +} + +func TestHashUserID_Consistency(t *testing.T) { + userID := "test@example.com" + + hash1, err1 := HashUserID(userID) + if err1 != nil { + t.Fatalf("First HashUserID() error = %v", err1) + } + + hash2, err2 := HashUserID(userID) + if err2 != nil { + t.Fatalf("Second HashUserID() error = %v", err2) + } + + if hash1 != hash2 { + t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2) + } +} + +func TestHashUserID_Uniqueness(t *testing.T) { + tests := []struct { + userID1 string + userID2 string + }{ + {"user1@example.com", "user2@example.com"}, + {"alice@domain.com", "bob@domain.com"}, + {"test", "test1"}, + {"", "a"}, + } + + for _, tt := range tests { + hash1, err1 := HashUserID(tt.userID1) + if err1 != nil { + t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1) + } + + hash2, err2 := HashUserID(tt.userID2) + if err2 != nil { + t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2) + } + + if hash1 == hash2 { + t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1) + } + } +} + +func TestUserIDHash_String(t *testing.T) { + tests := []struct { + name string + hash UserIDHash + expected string + }{ + { + name: "zero hash", + hash: [16]byte{}, + expected: "00000000000000000000000000000000", + }, + { + name: "small value", + hash: [16]byte{15: 0xff}, + expected: "000000000000000000000000000000ff", + }, + { + name: "large value", + hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe}, + expected: "0000000000000000deadbeefcafebabe", + }, + { + name: "max value", + hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + expected: "ffffffffffffffffffffffffffffffff", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.hash.String() + if result != tt.expected { + t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestUserIDHash_String_Length(t *testing.T) { + // Test that String() always returns 32 hex characters (16 bytes * 2) + userID := "test@example.com" + hash, err := HashUserID(userID) + if err != nil { + t.Fatalf("HashUserID() error = %v", err) + } + + result := hash.String() + if len(result) != 32 { + t.Errorf("UserIDHash.String() length = %d, want 32", len(result)) + } + + // Verify it's valid hex + for i, c := range result { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c) + } + } +} + +func TestHashUserID_KnownValues(t *testing.T) { + // Test with known BLAKE2b-128 values to ensure correct implementation + tests := []struct { + name string + userID string + expected UserIDHash + }{ + { + name: "empty string", + userID: "", + // BLAKE2b-128 of empty string + expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70}, + }, + { + name: "single character 'a'", + userID: "a", + // BLAKE2b-128 of "a" + expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashUserID(tt.userID) + if err != nil { + t.Errorf("HashUserID() error = %v", err) + return + } + + if hash != tt.expected { + t.Errorf("HashUserID(%q) = %x, want %x", + tt.userID, hash, tt.expected) + } + }) + } +} + +func BenchmarkHashUserID(b *testing.B) { + userID := "user@example.com" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = HashUserID(userID) + } +} + +func BenchmarkUserIDHash_String(b *testing.B) { + hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = hash.String() + } +} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { diff --git a/util/common.go b/util/common.go index 27adb9d13..89903b609 100644 --- a/util/common.go +++ b/util/common.go @@ -1,6 +1,19 @@ package util -import "os" +import ( + "os" + "os/exec" + + "github.com/skratchdot/open-golang/open" +) + +// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func OpenBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} // SliceDiff returns the elements in slice `x` that are not in slice `y` func SliceDiff(x, y []string) []string { diff --git a/version/update.go b/version/update.go index 272eef4c6..a324d97fe 100644 --- a/version/update.go +++ b/version/update.go @@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update { currentVersion, _ = goversion.NewVersion("0.0.0") } - latestAvailable, _ := goversion.NewVersion("0.0.0") - u := &Update{ - httpAgent: httpAgent, - latestAvailable: latestAvailable, - uiVersion: currentVersion, - fetchTicker: time.NewTicker(fetchPeriod), - fetchDone: make(chan struct{}), + httpAgent: httpAgent, + uiVersion: currentVersion, + fetchDone: make(chan struct{}), } - go u.startFetcher() + + return u +} + +func NewUpdateAndStart(httpAgent string) *Update { + u := NewUpdate(httpAgent) + go u.StartFetcher() + return u } // StopWatch stop the version info fetch loop func (u *Update) StopWatch() { + if u.fetchTicker == nil { + return + } + u.fetchTicker.Stop() select { @@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) { } } -func (u *Update) startFetcher() { +func (u *Update) LatestVersion() *goversion.Version { + u.versionsLock.Lock() + defer u.versionsLock.Unlock() + return u.latestAvailable +} + +func (u *Update) StartFetcher() { + if u.fetchTicker != nil { + return + } + u.fetchTicker = time.NewTicker(fetchPeriod) + if changed := u.fetchVersion(); changed { u.checkUpdate() } @@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool { u.versionsLock.Lock() defer u.versionsLock.Unlock() + if u.latestAvailable == nil { + return false + } + if u.latestAvailable.GreaterThan(u.uiVersion) { return true } diff --git a/version/update_test.go b/version/update_test.go index a733714cf..d5d60800e 100644 --- a/version/update_test.go +++ b/version/update_test.go @@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true diff --git a/version/url_windows.go b/version/url_windows.go index 14fdb7ae6..a0fb6e5dd 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -6,7 +6,7 @@ import ( ) const ( - urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExe = "https://pkgs.netbird.io/windows/x64" urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) @@ -18,11 +18,11 @@ func DownloadUrl() string { if err != nil { return downloadURL } - + url := urlWinExe if runtime.GOARCH == "arm64" { url = urlWinExeArm } - + return url }