diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..e9ffaf8a3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,14 @@ +blank_issues_enabled: true +contact_links: + - name: Community Support + url: https://forum.netbird.io/ + about: Community support forum + - name: Cloud Support + url: https://docs.netbird.io/help/report-bug-issues + about: Contact us for support + - name: Client/Connection Troubleshooting + url: https://docs.netbird.io/help/troubleshooting-client + about: See our client troubleshooting guide for help addressing common issues + - name: Self-host Troubleshooting + url: https://docs.netbird.io/selfhosted/troubleshooting + about: See our self-host troubleshooting guide for help addressing common issues diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index d1d2a8e50..a721cb516 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -31,7 +31,7 @@ jobs: while IFS= read -r dir; do echo "=== Checking $dir ===" # Search for problematic imports, excluding test files - RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) + RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true) if [ -n "$RESULTS" ]; then echo "❌ Found problematic dependencies:" echo "$RESULTS" @@ -88,7 +88,7 @@ jobs: IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") # Check if any importer is NOT in management/signal/relay - BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1) + BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1) if [ -n "$BSD_IMPORTER" ]; then echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 8af4046a7..8e672043d 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -63,10 +63,15 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV + - name: Generate test script + run: | + $packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } + $goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe" + $cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1" + Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 56450d45f..62dfe9bce 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA skip: go.mod,go.sum,**/proxy/web/** golangci: strategy: diff --git a/.github/workflows/pr-title-check.yml b/.github/workflows/pr-title-check.yml new file mode 100644 index 000000000..a2e6ce219 --- /dev/null +++ b/.github/workflows/pr-title-check.yml @@ -0,0 +1,51 @@ +name: PR Title Check + +on: + pull_request: + types: [opened, edited, synchronize, reopened] + +jobs: + check-title: + runs-on: ubuntu-latest + steps: + - name: Validate PR title prefix + uses: actions/github-script@v7 + with: + script: | + const title = context.payload.pull_request.title; + const allowedTags = [ + 'management', + 'client', + 'signal', + 'proxy', + 'relay', + 'misc', + 'infrastructure', + 'self-hosted', + 'doc', + ]; + + const pattern = /^\[([^\]]+)\]\s+.+/; + const match = title.match(pattern); + + if (!match) { + core.setFailed( + `PR title must start with a tag in brackets.\n` + + `Example: [client] fix something\n` + + `Allowed tags: ${allowedTags.join(', ')}` + ); + return; + } + + const tags = match[1].split(',').map(t => t.trim().toLowerCase()); + + const invalid = tags.filter(t => !allowedTags.includes(t)); + if (invalid.length > 0) { + core.setFailed( + `Invalid tag(s): ${invalid.join(', ')}\n` + + `Allowed tags: ${allowedTags.join(', ')}` + ); + return; + } + + console.log(`Valid PR title tags: [${tags.join(', ')}]`); diff --git a/.github/workflows/proto-version-check.yml b/.github/workflows/proto-version-check.yml new file mode 100644 index 000000000..ea300419d --- /dev/null +++ b/.github/workflows/proto-version-check.yml @@ -0,0 +1,62 @@ +name: Proto Version Check + +on: + pull_request: + paths: + - "**/*.pb.go" + +jobs: + check-proto-versions: + runs-on: ubuntu-latest + steps: + - name: Check for proto tool version changes + uses: actions/github-script@v7 + with: + script: | + const files = await github.paginate(github.rest.pulls.listFiles, { + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.issue.number, + per_page: 100, + }); + + const pbFiles = files.filter(f => f.filename.endsWith('.pb.go')); + const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename); + if (missingPatch.length > 0) { + core.setFailed( + `Cannot inspect patch data for:\n` + + missingPatch.map(f => `- ${f}`).join('\n') + + `\nThis can happen with very large PRs. Verify proto versions manually.` + ); + return; + } + const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/; + const violations = []; + + for (const file of pbFiles) { + const changed = file.patch + .split('\n') + .filter(line => versionPattern.test(line)); + if (changed.length > 0) { + violations.push({ + file: file.filename, + lines: changed, + }); + } + } + + if (violations.length > 0) { + const details = violations.map(v => + `${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}` + ).join('\n\n'); + + core.setFailed( + `Proto version strings changed in generated files.\n` + + `This usually means the wrong protoc or protoc-gen-go version was used.\n` + + `Regenerate with the matching tool versions.\n\n` + + details + ); + return; + } + + console.log('No proto version string changes detected'); diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d1f085b47..5ada1033d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,8 +9,8 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.1.1" - GORELEASER_VER: "v2.3.2" + SIGN_PIPE_VER: "v0.1.2" + GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -169,6 +169,14 @@ jobs: - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu + - name: Decode GPG signing key + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository + env: + GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} + run: | + echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc + echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV + - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - name: Generate windows syso amd64 @@ -186,18 +194,54 @@ jobs: HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - - name: Tag and push PR images (amd64 only) - if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository + GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} + NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} + - name: Verify RPM signatures run: | - PR_TAG="pr-${{ github.event.pull_request.number }}" + docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' + dnf install -y -q rpm-sign curl >/dev/null 2>&1 + curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key + rpm --import /tmp/rpm-pub.key + echo "=== Verifying RPM signatures ===" + for rpm_file in /dist/*amd64*.rpm; do + [ -f "$rpm_file" ] || continue + echo "--- $(basename $rpm_file) ---" + rpm -K "$rpm_file" + done + ' + - name: Clean up GPG key + if: always() + run: rm -f /tmp/gpg-rpm-signing-key.asc + - name: Tag and push images (amd64 only) + if: | + (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || + (github.event_name == 'push' && github.ref == 'refs/heads/main') + run: | + resolve_tags() { + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + echo "pr-${{ github.event.pull_request.number }}" + else + echo "main sha-$(git rev-parse --short HEAD)" + fi + } + + tag_and_push() { + local src="$1" img_name tag dst + img_name="${src%%:*}" + for tag in $(resolve_tags); do + dst="${img_name}:${tag}" + echo "Tagging ${src} -> ${dst}" + docker tag "$src" "$dst" + docker push "$dst" + done + } + + export -f tag_and_push resolve_tags + echo '${{ steps.goreleaser.outputs.artifacts }}' | \ jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ grep '^ghcr.io/' | while read -r SRC; do - IMG_NAME="${SRC%%:*}" - DST="${IMG_NAME}:${PR_TAG}" - echo "Tagging ${SRC} -> ${DST}" - docker tag "$SRC" "$DST" - docker push "$DST" + tag_and_push "$SRC" done - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 @@ -265,6 +309,14 @@ jobs: - name: Install dependencies run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 + - name: Decode GPG signing key + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository + env: + GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} + run: | + echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc + echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV + - name: Install LLVM-MinGW for ARM64 cross-compilation run: | cd /tmp @@ -289,6 +341,24 @@ jobs: HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} + GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} + NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} + - name: Verify RPM signatures + run: | + docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' + dnf install -y -q rpm-sign curl >/dev/null 2>&1 + curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key + rpm --import /tmp/rpm-pub.key + echo "=== Verifying RPM signatures ===" + for rpm_file in /dist/*.rpm; do + [ -f "$rpm_file" ] || continue + echo "--- $(basename $rpm_file) ---" + rpm -K "$rpm_file" + done + ' + - name: Clean up GPG key + if: always() + run: rm -f /tmp/gpg-rpm-signing-key.asc - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml index 47e45165b..81ae36e78 100644 --- a/.github/workflows/wasm-build-validation.yml +++ b/.github/workflows/wasm-build-validation.yml @@ -61,8 +61,8 @@ jobs: echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" - if [ ${SIZE} -gt 57671680 ]; then - echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!" + if [ ${SIZE} -gt 58720256 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!" exit 1 fi diff --git a/.goreleaser.yaml b/.goreleaser.yaml index c0a5efbbe..5ea479148 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -154,6 +154,26 @@ builds: - -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}} mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird-idp-migrate + dir: tools/idp-migrate + env: + - CGO_ENABLED=1 + - >- + {{- if eq .Runtime.Goos "linux" }} + {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} + {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} + {{- end }} + binary: netbird-idp-migrate + goos: + - linux + goarch: + - amd64 + - arm64 + - arm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + universal_binaries: - id: netbird @@ -166,18 +186,22 @@ archives: - netbird-wasm name_template: "{{ .ProjectName }}_{{ .Version }}" format: binary + - id: netbird-idp-migrate + builds: + - netbird-idp-migrate + name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}" nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ - id: netbird-deb + license: BSD-3-Clause + id: netbird_deb bindir: /usr/bin builds: - netbird formats: - deb - scripts: postinstall: "release_files/post_install.sh" preremove: "release_files/pre_remove.sh" @@ -185,16 +209,19 @@ nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ - id: netbird-rpm + license: BSD-3-Clause + id: netbird_rpm bindir: /usr/bin builds: - netbird formats: - rpm - scripts: postinstall: "release_files/post_install.sh" preremove: "release_files/pre_remove.sh" + rpm: + signature: + key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' dockers: - image_templates: - netbirdio/netbird:{{ .Version }}-amd64 @@ -876,7 +903,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird_deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -884,7 +911,7 @@ uploads: - name: yum ids: - - netbird-rpm + - netbird_rpm mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index a243702ea..470f1deaa 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -61,7 +61,7 @@ nfpms: - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ - id: netbird-ui-deb + id: netbird_ui_deb package_name: netbird-ui builds: - netbird-ui @@ -80,7 +80,7 @@ nfpms: - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ - id: netbird-ui-rpm + id: netbird_ui_rpm package_name: netbird-ui builds: - netbird-ui @@ -95,11 +95,14 @@ nfpms: dst: /usr/share/pixmaps/netbird.png dependencies: - netbird + rpm: + signature: + key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' uploads: - name: debian ids: - - netbird-ui-deb + - netbird_ui_deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -107,7 +110,7 @@ uploads: - name: yum ids: - - netbird-ui-rpm + - netbird_ui_rpm mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com diff --git a/CONTRIBUTOR_LICENSE_AGREEMENT.md b/CONTRIBUTOR_LICENSE_AGREEMENT.md index 1fdd072c9..b0a6ee218 100644 --- a/CONTRIBUTOR_LICENSE_AGREEMENT.md +++ b/CONTRIBUTOR_LICENSE_AGREEMENT.md @@ -1,7 +1,7 @@ ## Contributor License Agreement This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual -submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany, +submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany, referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions under which NetBird may utilize software contributions provided by the Contributor for inclusion in its software development projects. By submitting this Agreement, the Contributor confirms their acceptance diff --git a/Makefile b/Makefile index 43379e115..5d52b94fa 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint $(GOLANGCI_LINT): @echo "Installing golangci-lint..." @mkdir -p ./bin - @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest # Lint only changed files (fast, for pre-push) lint: $(GOLANGCI_LINT) diff --git a/README.md b/README.md index bca81c20b..dc84af2fd 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how ### Community projects - [NetBird installer script](https://github.com/physk/netbird-installer) - [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/) +- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings **Note**: The `main` branch may be in an *unstable or even broken state* during development. For stable versions, see [releases](https://github.com/netbirdio/netbird/releases). diff --git a/client/Dockerfile b/client/Dockerfile index 2ff0cca19..64d5ba04f 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.23.2 +FROM alpine:3.23.3 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ @@ -17,8 +17,7 @@ ENV \ NETBIRD_BIN="/usr/local/bin/netbird" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ - NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="5" + NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 5fa8de0a5..69d00aaf2 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -23,8 +23,7 @@ ENV \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_DISABLE_DNS="true" \ - NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index ccf32a90c..37e17a363 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( "os" "slices" "sync" + "time" "golang.org/x/exp/maps" @@ -15,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -26,6 +28,7 @@ import ( "github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" + types "github.com/netbirdio/netbird/upload-server/types" ) // ConnectionListener export internal Listener for mobile @@ -68,7 +71,30 @@ type Client struct { uiVersion string networkChangeListener listener.NetworkChangeListener + stateMu sync.RWMutex connectClient *internal.ConnectClient + config *profilemanager.Config + cacheDir string +} + +func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + c.config = cfg + c.cacheDir = cacheDir + c.connectClient = cc +} + +func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) { + c.stateMu.RLock() + defer c.stateMu.RUnlock() + return c.config, c.cacheDir, c.connectClient +} + +func (c *Client) getConnectClient() *internal.ConnectClient { + c.stateMu.RLock() + defer c.stateMu.RUnlock() + return c.connectClient } // NewClient instantiate a new Client @@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid cfgFile := platformFiles.ConfigurationFilePath() stateFile := platformFiles.StateFilePath() + cacheDir := platformFiles.CacheDir() log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) @@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + c.setState(cfg, cacheDir, connectClient) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR cfgFile := platformFiles.ConfigurationFilePath() stateFile := platformFiles.StateFilePath() + cacheDir := platformFiles.CacheDir() log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) @@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + c.setState(cfg, cacheDir, connectClient) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir) } // Stop the internal client and free the resources @@ -173,11 +203,12 @@ func (c *Client) Stop() { } func (c *Client) RenewTun(fd int) error { - if c.connectClient == nil { + cc := c.getConnectClient() + if cc == nil { return fmt.Errorf("engine not running") } - e := c.connectClient.Engine() + e := cc.Engine() if e == nil { return fmt.Errorf("engine not initialized") } @@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error { return e.RenewTun(fd) } +// DebugBundle generates a debug bundle, uploads it, and returns the upload key. +// It works both with and without a running engine. +func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) { + cfg, cacheDir, cc := c.stateSnapshot() + + // If the engine hasn't been started, load config from disk + if cfg == nil { + var err error + cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ + ConfigPath: platformFiles.ConfigurationFilePath(), + }) + if err != nil { + return "", fmt.Errorf("load config: %w", err) + } + cacheDir = platformFiles.CacheDir() + } + + deps := debug.GeneratorDependencies{ + InternalConfig: cfg, + StatusRecorder: c.recorder, + TempDir: cacheDir, + } + + if cc != nil { + resp, err := cc.GetLatestSyncResponse() + if err != nil { + log.Warnf("get latest sync response: %v", err) + } + deps.SyncResponse = resp + + if e := cc.Engine(); e != nil { + if cm := e.GetClientMetrics(); cm != nil { + deps.ClientMetrics = cm + } + } + } + + bundleGenerator := debug.NewBundleGenerator( + deps, + debug.BundleConfig{ + Anonymize: anonymize, + IncludeSystemInfo: true, + }, + ) + + path, err := bundleGenerator.Generate() + if err != nil { + return "", fmt.Errorf("generate debug bundle: %w", err) + } + defer func() { + if err := os.Remove(path); err != nil { + log.Errorf("failed to remove debug bundle file: %v", err) + } + }() + + uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path) + if err != nil { + return "", fmt.Errorf("upload debug bundle: %w", err) + } + + log.Infof("debug bundle uploaded with key %s", key) + return key, nil +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -205,7 +303,7 @@ func (c *Client) PeersList() *PeerInfoArray { pi := PeerInfo{ p.IP, p.FQDN, - p.ConnStatus.String(), + int(p.ConnStatus), PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi @@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray { } func (c *Client) Networks() *NetworkArray { - if c.connectClient == nil { + cc := c.getConnectClient() + if cc == nil { log.Error("not connected") return nil } - engine := c.connectClient.Engine() + engine := cc.Engine() if engine == nil { log.Error("could not get engine") return nil @@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error { } func (c *Client) getRouteManager() (routemanager.Manager, error) { - client := c.connectClient + client := c.getConnectClient() if client == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index b03947da1..4ec22f3ab 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -2,11 +2,20 @@ package android +import "github.com/netbirdio/netbird/client/internal/peer" + +// Connection status constants exported via gomobile. +const ( + ConnStatusIdle = int(peer.StatusIdle) + ConnStatusConnecting = int(peer.StatusConnecting) + ConnStatusConnected = int(peer.StatusConnected) +) + // PeerInfo describe information about the peers. It designed for the UI usage type PeerInfo struct { IP string FQDN string - ConnStatus string // Todo replace to enum + ConnStatus int Routes PeerRoutes } diff --git a/client/android/platform_files.go b/client/android/platform_files.go index f0c369750..3be40c0bd 100644 --- a/client/android/platform_files.go +++ b/client/android/platform_files.go @@ -7,4 +7,5 @@ package android type PlatformFiles interface { ConfigurationFilePath() string StateFilePath() string + CacheDir() string } diff --git a/client/cmd/debug.go b/client/cmd/debug.go index e480df4d7..e3d3afe5f 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { if stateWasDown { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { - return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird up") + time.Sleep(time.Second * 10) } - cmd.Println("netbird up") - time.Sleep(time.Second * 10) } initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE @@ -198,10 +199,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Log level set to trace.") } + needsRestoreUp := false if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { - return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message()) + } else { + needsRestoreUp = !stateWasDown + cmd.Println("netbird down") } - cmd.Println("netbird down") time.Sleep(1 * time.Second) @@ -209,13 +213,15 @@ func runForDuration(cmd *cobra.Command, args []string) error { if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { - return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message()) } if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { - return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) + } else { + needsRestoreUp = false + cmd.Println("netbird up") } - cmd.Println("netbird up") time.Sleep(3 * time.Second) @@ -261,18 +267,28 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + if needsRestoreUp { + if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { + cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird up (restored)") + } + } + if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { - return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird down") } - cmd.Println("netbird down") } if !initialLevelTrace { if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil { - return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("Log level restored to", initialLogLevel.GetLevel()) } - cmd.Println("Log level restored to", initialLogLevel.GetLevel()) } cmd.Printf("Local file:\n%s\n", resp.GetPath()) diff --git a/client/cmd/expose.go b/client/cmd/expose.go new file mode 100644 index 000000000..c48a6adac --- /dev/null +++ b/client/cmd/expose.go @@ -0,0 +1,287 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/signal" + "regexp" + "strconv" + "strings" + "syscall" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal/expose" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util" +) + +var pinRegexp = regexp.MustCompile(`^\d{6}$`) + +var ( + exposePin string + exposePassword string + exposeUserGroups []string + exposeDomain string + exposeNamePrefix string + exposeProtocol string + exposeExternalPort uint16 +) + +var exposeCmd = &cobra.Command{ + Use: "expose ", + Short: "Expose a local port via the NetBird reverse proxy", + Args: cobra.ExactArgs(1), + Example: ` netbird expose --with-password safe-pass 8080 + netbird expose --protocol tcp 5432 + netbird expose --protocol tcp --with-external-port 5433 5432 + netbird expose --protocol tls --with-custom-domain tls.example.com 4443`, + RunE: exposeFn, +} + +func init() { + exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)") + exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)") + exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)") + exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)") + exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)") + exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)") + exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)") +} + +// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags. +func isClusterProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp", "tls": + return true + default: + return false + } +} + +// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP) +// where domain display doesn't apply. TLS uses SNI so it has a domain. +func isPortBasedProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp": + return true + default: + return false + } +} + +// extractPort returns the port portion of a URL like "tcp://host:12345", or +// falls back to the given default formatted as a string. +func extractPort(serviceURL string, fallback uint16) string { + u := serviceURL + if idx := strings.Index(u, "://"); idx != -1 { + u = u[idx+3:] + } + if i := strings.LastIndex(u, ":"); i != -1 { + if p := u[i+1:]; p != "" { + return p + } + } + return strconv.FormatUint(uint64(fallback), 10) +} + +// resolveExternalPort returns the effective external port, defaulting to the target port. +func resolveExternalPort(targetPort uint64) uint16 { + if exposeExternalPort != 0 { + return exposeExternalPort + } + return uint16(targetPort) +} + +func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { + port, err := strconv.ParseUint(portStr, 10, 32) + if err != nil { + return 0, fmt.Errorf("invalid port number: %s", portStr) + } + if port == 0 || port > 65535 { + return 0, fmt.Errorf("invalid port number: must be between 1 and 65535") + } + + if !isProtocolValid(exposeProtocol) { + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) + } + + if isClusterProtocol(exposeProtocol) { + if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 { + return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol) + } + } else if cmd.Flags().Changed("with-external-port") { + return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol) + } + + if exposePin != "" && !pinRegexp.MatchString(exposePin) { + return 0, fmt.Errorf("invalid pin: must be exactly 6 digits") + } + + if cmd.Flags().Changed("with-password") && exposePassword == "" { + return 0, fmt.Errorf("password cannot be empty") + } + + if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 { + return 0, fmt.Errorf("user groups cannot be empty") + } + + return port, nil +} + +func isProtocolValid(exposeProtocol string) bool { + switch strings.ToLower(exposeProtocol) { + case "http", "https", "tcp", "udp", "tls": + return true + default: + return false + } +} + +func exposeFn(cmd *cobra.Command, args []string) error { + SetFlagsFromEnvVars(rootCmd) + + if err := util.InitLog(logLevel, util.LogConsole); err != nil { + log.Errorf("failed initializing log %v", err) + return err + } + + cmd.Root().SilenceUsage = false + + port, err := validateExposeFlags(cmd, args[0]) + if err != nil { + return err + } + + cmd.Root().SilenceUsage = true + + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + cancel() + }() + + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + return fmt.Errorf("connect to daemon: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("failed to close daemon connection: %v", err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + + protocol, err := toExposeProtocol(exposeProtocol) + if err != nil { + return err + } + + req := &proto.ExposeServiceRequest{ + Port: uint32(port), + Protocol: protocol, + Pin: exposePin, + Password: exposePassword, + UserGroups: exposeUserGroups, + Domain: exposeDomain, + NamePrefix: exposeNamePrefix, + } + if isClusterProtocol(exposeProtocol) { + req.ListenPort = uint32(resolveExternalPort(port)) + } + + stream, err := client.ExposeService(ctx, req) + if err != nil { + return fmt.Errorf("expose service: %v", status.Convert(err).Message()) + } + + if err := handleExposeReady(cmd, stream, port); err != nil { + return err + } + + return waitForExposeEvents(cmd, ctx, stream) +} + +func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) { + p, err := expose.ParseProtocolType(exposeProtocol) + if err != nil { + return 0, fmt.Errorf("invalid protocol: %w", err) + } + + switch p { + case expose.ProtocolHTTP: + return proto.ExposeProtocol_EXPOSE_HTTP, nil + case expose.ProtocolHTTPS: + return proto.ExposeProtocol_EXPOSE_HTTPS, nil + case expose.ProtocolTCP: + return proto.ExposeProtocol_EXPOSE_TCP, nil + case expose.ProtocolUDP: + return proto.ExposeProtocol_EXPOSE_UDP, nil + case expose.ProtocolTLS: + return proto.ExposeProtocol_EXPOSE_TLS, nil + default: + return 0, fmt.Errorf("unhandled protocol type: %d", p) + } +} + +func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error { + event, err := stream.Recv() + if err != nil { + return fmt.Errorf("receive expose event: %v", status.Convert(err).Message()) + } + + ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready) + if !ok { + return fmt.Errorf("unexpected expose event: %T", event.Event) + } + printExposeReady(cmd, ready.Ready, port) + return nil +} + +func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) { + cmd.Println("Service exposed successfully!") + cmd.Printf(" Name: %s\n", r.ServiceName) + if r.ServiceUrl != "" { + cmd.Printf(" URL: %s\n", r.ServiceUrl) + } + if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) { + cmd.Printf(" Domain: %s\n", r.Domain) + } + cmd.Printf(" Protocol: %s\n", exposeProtocol) + cmd.Printf(" Internal: %d\n", port) + if isClusterProtocol(exposeProtocol) { + cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port))) + } + if r.PortAutoAssigned && exposeExternalPort != 0 { + cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort) + } + cmd.Println() + cmd.Println("Press Ctrl+C to stop exposing.") +} + +func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error { + for { + _, err := stream.Recv() + if err != nil { + if ctx.Err() != nil { + cmd.Println("\nService stopped.") + //nolint:nilerr + return nil + } + if errors.Is(err, io.EOF) { + return fmt.Errorf("connection to daemon closed unexpectedly") + } + return fmt.Errorf("stream error: %w", err) + } + } +} diff --git a/client/cmd/root.go b/client/cmd/root.go index f4f4f6052..c872fe9f6 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + daddr "github.com/netbirdio/netbird/client/internal/daemonaddr" "github.com/netbirdio/netbird/client/internal/profilemanager" ) @@ -74,12 +75,22 @@ var ( mtu uint16 profilesDisabled bool updateSettingsDisabled bool + networksDisabled bool rootCmd = &cobra.Command{ Use: "netbird", Short: "", Long: "", SilenceUsage: true, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + SetFlagsFromEnvVars(cmd.Root()) + + // Don't resolve for service commands — they create the socket, not connect to it. + if !isServiceCmd(cmd) { + daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr) + } + return nil + }, } ) @@ -144,6 +155,7 @@ func init() { rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(profileCmd) + rootCmd.AddCommand(exposeCmd) networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) @@ -385,7 +397,6 @@ func migrateToNetbird(oldPath, newPath string) bool { } func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { - SetFlagsFromEnvVars(rootCmd) cmd.SetOut(cmd.OutOrStdout()) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) @@ -398,3 +409,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { return conn, nil } + +// isServiceCmd returns true if cmd is the "service" command or a child of it. +func isServiceCmd(cmd *cobra.Command) bool { + for c := cmd; c != nil; c = c.Parent() { + if c.Name() == "service" { + return true + } + } + return false +} diff --git a/client/cmd/service.go b/client/cmd/service.go index e55465875..f1123ce8c 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -41,13 +41,16 @@ func init() { defaultServiceName = "Netbird" } - serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) + serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") + serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") serviceEnvDesc := `Sets extra environment variables for the service. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` + + `New keys are merged with previously saved env vars; existing keys are overwritten. ` + + `Use --service-env "" to clear all saved env vars. ` + `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value` installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0545ce6b7..0943b6184 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } @@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error { // Common setup for service control commands func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) { - SetFlagsFromEnvVars(rootCmd) + // rootCmd env vars are already applied by PersistentPreRunE. SetFlagsFromEnvVars(serviceCmd) cmd.SetOut(cmd.OutOrStdout()) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index f6828d96a..5ada6f633 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -59,6 +59,10 @@ func buildServiceArguments() []string { args = append(args, "--disable-update-settings") } + if networksDisabled { + args = append(args, "--disable-networks") + } + return args } @@ -119,6 +123,10 @@ var installCmd = &cobra.Command{ return err } + if err := loadAndApplyServiceParams(cmd); err != nil { + cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err) + } + svcConfig, err := createServiceConfigForInstall() if err != nil { return err @@ -136,6 +144,10 @@ var installCmd = &cobra.Command{ return fmt.Errorf("install service: %w", err) } + if err := saveServiceParams(currentServiceParams()); err != nil { + cmd.PrintErrf("Warning: failed to save service params: %v\n", err) + } + cmd.Println("NetBird service has been installed") return nil }, @@ -187,6 +199,10 @@ This command will temporarily stop the service, update its configuration, and re return err } + if err := loadAndApplyServiceParams(cmd); err != nil { + cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err) + } + wasRunning, err := isServiceRunning() if err != nil && !errors.Is(err, ErrGetServiceStatus) { return fmt.Errorf("check service status: %w", err) @@ -222,6 +238,10 @@ This command will temporarily stop the service, update its configuration, and re return fmt.Errorf("install service with new config: %w", err) } + if err := saveServiceParams(currentServiceParams()); err != nil { + cmd.PrintErrf("Warning: failed to save service params: %v\n", err) + } + if wasRunning { cmd.Println("Starting NetBird service...") if err := s.Start(); err != nil { diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go new file mode 100644 index 000000000..5a86aebc6 --- /dev/null +++ b/client/cmd/service_params.go @@ -0,0 +1,218 @@ +//go:build !ios && !android + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "os" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/configs" + "github.com/netbirdio/netbird/util" +) + +const serviceParamsFile = "service.json" + +// serviceParams holds install-time service parameters that persist across +// uninstall/reinstall cycles. Saved to /service.json. +type serviceParams struct { + LogLevel string `json:"log_level"` + DaemonAddr string `json:"daemon_addr"` + ManagementURL string `json:"management_url,omitempty"` + ConfigPath string `json:"config_path,omitempty"` + LogFiles []string `json:"log_files,omitempty"` + DisableProfiles bool `json:"disable_profiles,omitempty"` + DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + DisableNetworks bool `json:"disable_networks,omitempty"` + ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` +} + +// serviceParamsPath returns the path to the service params file. +func serviceParamsPath() string { + return filepath.Join(configs.StateDir, serviceParamsFile) +} + +// loadServiceParams reads saved service parameters from disk. +// Returns nil with no error if the file does not exist. +func loadServiceParams() (*serviceParams, error) { + path := serviceParamsPath() + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil //nolint:nilnil + } + return nil, fmt.Errorf("read service params %s: %w", path, err) + } + + var params serviceParams + if err := json.Unmarshal(data, ¶ms); err != nil { + return nil, fmt.Errorf("parse service params %s: %w", path, err) + } + + return ¶ms, nil +} + +// saveServiceParams writes current service parameters to disk atomically +// with restricted permissions. +func saveServiceParams(params *serviceParams) error { + path := serviceParamsPath() + if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil { + return fmt.Errorf("save service params: %w", err) + } + return nil +} + +// currentServiceParams captures the current state of all package-level +// variables into a serviceParams struct. +func currentServiceParams() *serviceParams { + params := &serviceParams{ + LogLevel: logLevel, + DaemonAddr: daemonAddr, + ManagementURL: managementURL, + ConfigPath: configPath, + LogFiles: logFiles, + DisableProfiles: profilesDisabled, + DisableUpdateSettings: updateSettingsDisabled, + DisableNetworks: networksDisabled, + } + + if len(serviceEnvVars) > 0 { + parsed, err := parseServiceEnvVars(serviceEnvVars) + if err == nil { + params.ServiceEnvVars = parsed + } + } + + return params +} + +// loadAndApplyServiceParams loads saved params from disk and applies them +// to any flags that were not explicitly set. +func loadAndApplyServiceParams(cmd *cobra.Command) error { + params, err := loadServiceParams() + if err != nil { + return err + } + applyServiceParams(cmd, params) + return nil +} + +// applyServiceParams merges saved parameters into package-level variables +// for any flag that was not explicitly set by the user (via CLI or env var). +// Flags that were Changed() are left untouched. +func applyServiceParams(cmd *cobra.Command, params *serviceParams) { + if params == nil { + return + } + + // For fields with non-empty defaults (log-level, daemon-addr), keep the + // != "" guard so that an older service.json missing the field doesn't + // clobber the default with an empty string. + if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" { + logLevel = params.LogLevel + } + + if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" { + daemonAddr = params.DaemonAddr + } + + // For optional fields where empty means "use default", always apply so + // that an explicit clear (--management-url "") persists across reinstalls. + if !rootCmd.PersistentFlags().Changed("management-url") { + managementURL = params.ManagementURL + } + + if !rootCmd.PersistentFlags().Changed("config") { + configPath = params.ConfigPath + } + + if !rootCmd.PersistentFlags().Changed("log-file") { + logFiles = params.LogFiles + } + + if !serviceCmd.PersistentFlags().Changed("disable-profiles") { + profilesDisabled = params.DisableProfiles + } + + if !serviceCmd.PersistentFlags().Changed("disable-update-settings") { + updateSettingsDisabled = params.DisableUpdateSettings + } + + if !serviceCmd.PersistentFlags().Changed("disable-networks") { + networksDisabled = params.DisableNetworks + } + + applyServiceEnvParams(cmd, params) +} + +// applyServiceEnvParams merges saved service environment variables. +// If --service-env was explicitly set with values, explicit values win on key +// conflict but saved keys not in the explicit set are carried over. +// If --service-env was explicitly set to empty, all saved env vars are cleared. +// If --service-env was not set, saved env vars are used entirely. +func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) { + if !cmd.Flags().Changed("service-env") { + if len(params.ServiceEnvVars) > 0 { + // No explicit env vars: rebuild serviceEnvVars from saved params. + serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + } + return + } + + // Flag was explicitly set: parse what the user provided. + explicit, err := parseServiceEnvVars(serviceEnvVars) + if err != nil { + cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err) + return + } + + // If the user passed an empty value (e.g. --service-env ""), clear all + // saved env vars rather than merging. + if len(explicit) == 0 { + serviceEnvVars = nil + return + } + + if len(params.ServiceEnvVars) == 0 { + return + } + + // Merge saved values underneath explicit ones. + merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit)) + maps.Copy(merged, params.ServiceEnvVars) + maps.Copy(merged, explicit) // explicit wins on conflict + serviceEnvVars = envMapToSlice(merged) +} + +var resetParamsCmd = &cobra.Command{ + Use: "reset-params", + Short: "Remove saved service install parameters", + Long: "Removes the saved service.json file so the next install uses default parameters.", + RunE: func(cmd *cobra.Command, args []string) error { + path := serviceParamsPath() + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + cmd.Println("No saved service parameters found") + return nil + } + return fmt.Errorf("remove service params: %w", err) + } + cmd.Printf("Removed saved service parameters (%s)\n", path) + return nil + }, +} + +// envMapToSlice converts a map of env vars to a KEY=VALUE slice. +func envMapToSlice(m map[string]string) []string { + s := make([]string, 0, len(m)) + for k, v := range m { + s = append(s, k+"="+v) + } + return s +} diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go new file mode 100644 index 000000000..7e04e5abe --- /dev/null +++ b/client/cmd/service_params_test.go @@ -0,0 +1,559 @@ +//go:build !ios && !android + +package cmd + +import ( + "encoding/json" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/configs" +) + +func TestServiceParamsPath(t *testing.T) { + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + + configs.StateDir = "/var/lib/netbird" + assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath()) + + configs.StateDir = "/custom/state" + assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath()) +} + +func TestSaveAndLoadServiceParams(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + params := &serviceParams{ + LogLevel: "debug", + DaemonAddr: "unix:///var/run/netbird.sock", + ManagementURL: "https://my.server.com", + ConfigPath: "/etc/netbird/config.json", + LogFiles: []string{"/var/log/netbird/client.log", "console"}, + DisableProfiles: true, + DisableUpdateSettings: false, + ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"}, + } + + err := saveServiceParams(params) + require.NoError(t, err) + + // Verify the file exists and is valid JSON. + data, err := os.ReadFile(filepath.Join(tmpDir, "service.json")) + require.NoError(t, err) + assert.True(t, json.Valid(data)) + + loaded, err := loadServiceParams() + require.NoError(t, err) + require.NotNil(t, loaded) + + assert.Equal(t, params.LogLevel, loaded.LogLevel) + assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr) + assert.Equal(t, params.ManagementURL, loaded.ManagementURL) + assert.Equal(t, params.ConfigPath, loaded.ConfigPath) + assert.Equal(t, params.LogFiles, loaded.LogFiles) + assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles) + assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings) + assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars) +} + +func TestLoadServiceParams_FileNotExists(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + params, err := loadServiceParams() + assert.NoError(t, err) + assert.Nil(t, params) +} + +func TestLoadServiceParams_InvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600) + require.NoError(t, err) + + params, err := loadServiceParams() + assert.Error(t, err) + assert.Nil(t, params) +} + +func TestCurrentServiceParams(t *testing.T) { + origLogLevel := logLevel + origDaemonAddr := daemonAddr + origManagementURL := managementURL + origConfigPath := configPath + origLogFiles := logFiles + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { + logLevel = origLogLevel + daemonAddr = origDaemonAddr + managementURL = origManagementURL + configPath = origConfigPath + logFiles = origLogFiles + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + serviceEnvVars = origServiceEnvVars + }) + + logLevel = "trace" + daemonAddr = "tcp://127.0.0.1:9999" + managementURL = "https://mgmt.example.com" + configPath = "/tmp/test-config.json" + logFiles = []string{"/tmp/test.log"} + profilesDisabled = true + updateSettingsDisabled = true + serviceEnvVars = []string{"FOO=bar", "BAZ=qux"} + + params := currentServiceParams() + + assert.Equal(t, "trace", params.LogLevel) + assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr) + assert.Equal(t, "https://mgmt.example.com", params.ManagementURL) + assert.Equal(t, "/tmp/test-config.json", params.ConfigPath) + assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles) + assert.True(t, params.DisableProfiles) + assert.True(t, params.DisableUpdateSettings) + assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars) +} + +func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) { + origLogLevel := logLevel + origDaemonAddr := daemonAddr + origManagementURL := managementURL + origConfigPath := configPath + origLogFiles := logFiles + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { + logLevel = origLogLevel + daemonAddr = origDaemonAddr + managementURL = origManagementURL + configPath = origConfigPath + logFiles = origLogFiles + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + serviceEnvVars = origServiceEnvVars + }) + + // Reset all flags to defaults. + logLevel = "info" + daemonAddr = "unix:///var/run/netbird.sock" + managementURL = "" + configPath = "/etc/netbird/config.json" + logFiles = []string{"/var/log/netbird/client.log"} + profilesDisabled = false + updateSettingsDisabled = false + serviceEnvVars = nil + + // Reset Changed state on all relevant flags. + rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + // Simulate user explicitly setting --log-level via CLI. + logLevel = "warn" + require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn")) + + saved := &serviceParams{ + LogLevel: "debug", + DaemonAddr: "tcp://127.0.0.1:5555", + ManagementURL: "https://saved.example.com", + ConfigPath: "/saved/config.json", + LogFiles: []string{"/saved/client.log"}, + DisableProfiles: true, + DisableUpdateSettings: true, + ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"}, + } + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + // log-level was Changed, so it should keep "warn", not use saved "debug". + assert.Equal(t, "warn", logLevel) + + // All other fields were not Changed, so they should use saved values. + assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr) + assert.Equal(t, "https://saved.example.com", managementURL) + assert.Equal(t, "/saved/config.json", configPath) + assert.Equal(t, []string{"/saved/client.log"}, logFiles) + assert.True(t, profilesDisabled) + assert.True(t, updateSettingsDisabled) + assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars) +} + +func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) { + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + t.Cleanup(func() { + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + }) + + // Simulate current state where booleans are true (e.g. set by previous install). + profilesDisabled = true + updateSettingsDisabled = true + + // Reset Changed state so flags appear unset. + serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + // Saved params have both as false. + saved := &serviceParams{ + DisableProfiles: false, + DisableUpdateSettings: false, + } + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + assert.False(t, profilesDisabled, "saved false should override current true") + assert.False(t, updateSettingsDisabled, "saved false should override current true") +} + +func TestApplyServiceParams_ClearManagementURL(t *testing.T) { + origManagementURL := managementURL + t.Cleanup(func() { managementURL = origManagementURL }) + + managementURL = "https://leftover.example.com" + + // Simulate saved params where management URL was explicitly cleared. + saved := &serviceParams{ + LogLevel: "info", + DaemonAddr: "unix:///var/run/netbird.sock", + // ManagementURL intentionally empty: was cleared with --management-url "". + } + + rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value") +} + +func TestApplyServiceParams_NilParams(t *testing.T) { + origLogLevel := logLevel + t.Cleanup(func() { logLevel = origLogLevel }) + + logLevel = "info" + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + + // Should be a no-op. + applyServiceParams(cmd, nil) + assert.Equal(t, "info", logLevel) +} + +func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Set up a command with --service-env marked as Changed. + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit")) + + serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"} + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{ + "SAVED": "val", + "OVERLAP": "saved", + }, + } + + applyServiceEnvParams(cmd, saved) + + // Parse result for easier assertion. + result, err := parseServiceEnvVars(serviceEnvVars) + require.NoError(t, err) + + assert.Equal(t, "yes", result["EXPLICIT"]) + assert.Equal(t, "val", result["SAVED"]) + // Explicit wins on conflict. + assert.Equal(t, "explicit", result["OVERLAP"]) +} + +func TestApplyServiceEnvParams_NotChanged(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + serviceEnvVars = nil + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"FROM_SAVED": "val"}, + } + + applyServiceEnvParams(cmd, saved) + + result, err := parseServiceEnvVars(serviceEnvVars) + require.NoError(t, err) + assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result) +} + +func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "")) + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"}, + } + + applyServiceEnvParams(cmd, saved) + + assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars") +} + +func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + params := currentServiceParams() + + // After parsing, the empty string is skipped, resulting in an empty map. + // The map should still be set (not nil) so it overwrites saved values. + assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil") + assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string") +} + +// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are +// referenced in both currentServiceParams() and applyServiceParams(). If a new field is +// added to serviceParams but not wired into these functions, this test fails. +func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "service_params.go", nil, 0) + require.NoError(t, err) + + // Collect all JSON field names from the serviceParams struct. + structFields := extractStructJSONFields(t, file, "serviceParams") + require.NotEmpty(t, structFields, "failed to find serviceParams struct fields") + + // Collect field names referenced in currentServiceParams and applyServiceParams. + currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields) + applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields) + // applyServiceEnvParams handles ServiceEnvVars indirectly. + applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields) + for k, v := range applyEnvFields { + applyFields[k] = v + } + + for _, field := range structFields { + assert.Contains(t, currentFields, field, + "serviceParams field %q is not captured in currentServiceParams()", field) + assert.Contains(t, applyFields, field, + "serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field) + } +} + +// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references +// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because +// it flows through newSVCConfig() EnvVars, not CLI args. +func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "service_params.go", nil, 0) + require.NoError(t, err) + + structFields := extractStructJSONFields(t, file, "serviceParams") + require.NotEmpty(t, structFields) + + installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0) + require.NoError(t, err) + + // Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig). + fieldsNotInArgs := map[string]bool{ + "ServiceEnvVars": true, + } + + buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments") + + // Forward: every struct field must appear in buildServiceArguments. + for _, field := range structFields { + if fieldsNotInArgs[field] { + continue + } + globalVar := fieldToGlobalVar(field) + assert.Contains(t, buildFields, globalVar, + "serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar) + } + + // Reverse: every service-related global used in buildServiceArguments must + // have a corresponding serviceParams field. This catches a developer adding + // a new flag to buildServiceArguments without adding it to the struct. + globalToField := make(map[string]string, len(structFields)) + for _, field := range structFields { + globalToField[fieldToGlobalVar(field)] = field + } + // Identifiers in buildServiceArguments that are not service params + // (builtins, boilerplate, loop variables). + nonParamGlobals := map[string]bool{ + "args": true, "append": true, "string": true, "_": true, + "logFile": true, // range variable over logFiles + } + for ref := range buildFields { + if nonParamGlobals[ref] { + continue + } + _, inStruct := globalToField[ref] + assert.True(t, inStruct, + "buildServiceArguments() references global %q which has no corresponding serviceParams field", ref) + } +} + +// extractStructJSONFields returns field names from a named struct type. +func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string { + t.Helper() + var fields []string + ast.Inspect(file, func(n ast.Node) bool { + ts, ok := n.(*ast.TypeSpec) + if !ok || ts.Name.Name != structName { + return true + } + st, ok := ts.Type.(*ast.StructType) + if !ok { + return false + } + for _, f := range st.Fields.List { + if len(f.Names) > 0 { + fields = append(fields, f.Names[0].Name) + } + } + return false + }) + return fields +} + +// extractFuncFieldRefs returns which of the given field names appear inside the +// named function, either as selector expressions (params.FieldName) or as +// composite literal keys (&serviceParams{FieldName: ...}). +func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool { + t.Helper() + fieldSet := make(map[string]bool, len(fields)) + for _, f := range fields { + fieldSet[f] = true + } + + found := make(map[string]bool) + fn := findFuncDecl(file, funcName) + require.NotNil(t, fn, "function %s not found", funcName) + + ast.Inspect(fn.Body, func(n ast.Node) bool { + switch v := n.(type) { + case *ast.SelectorExpr: + if fieldSet[v.Sel.Name] { + found[v.Sel.Name] = true + } + case *ast.KeyValueExpr: + if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] { + found[ident.Name] = true + } + } + return true + }) + return found +} + +// extractFuncGlobalRefs returns all identifier names referenced in the named function body. +func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool { + t.Helper() + fn := findFuncDecl(file, funcName) + require.NotNil(t, fn, "function %s not found", funcName) + + refs := make(map[string]bool) + ast.Inspect(fn.Body, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + refs[ident.Name] = true + } + return true + }) + return refs +} + +func findFuncDecl(file *ast.File, name string) *ast.FuncDecl { + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if ok && fn.Name.Name == name { + return fn + } + } + return nil +} + +// fieldToGlobalVar maps serviceParams field names to the package-level variable +// names used in buildServiceArguments and applyServiceParams. +func fieldToGlobalVar(field string) string { + m := map[string]string{ + "LogLevel": "logLevel", + "DaemonAddr": "daemonAddr", + "ManagementURL": "managementURL", + "ConfigPath": "configPath", + "LogFiles": "logFiles", + "DisableProfiles": "profilesDisabled", + "DisableUpdateSettings": "updateSettingsDisabled", + "DisableNetworks": "networksDisabled", + "ServiceEnvVars": "serviceEnvVars", + } + if v, ok := m[field]; ok { + return v + } + // Default: lowercase first letter. + return strings.ToLower(field[:1]) + field[1:] +} + +func TestEnvMapToSlice(t *testing.T) { + m := map[string]string{"A": "1", "B": "2"} + s := envMapToSlice(m) + assert.Len(t, s, 2) + assert.Contains(t, s, "A=1") + assert.Contains(t, s, "B=2") +} + +func TestEnvMapToSlice_Empty(t *testing.T) { + s := envMapToSlice(map[string]string{}) + assert.Empty(t, s) +} diff --git a/client/cmd/service_test.go b/client/cmd/service_test.go index 6d75ca524..ce6f71550 100644 --- a/client/cmd/service_test.go +++ b/client/cmd/service_test.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "os" + "os/signal" "runtime" + "syscall" "testing" "time" @@ -13,6 +15,22 @@ import ( "github.com/stretchr/testify/require" ) +// TestMain intercepts when this test binary is run as a daemon subprocess. +// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with +// "service run ..." arguments. Since the test binary can't handle cobra CLI +// args, it exits immediately, causing daemon -r to respawn rapidly until +// hitting the rate limit and exiting. This makes service restart unreliable. +// Blocking here keeps the subprocess alive until the init system sends SIGTERM. +func TestMain(m *testing.M) { + if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM, os.Interrupt) + <-sig + return + } + os.Exit(m.Run()) +} + const ( serviceStartTimeout = 10 * time.Second serviceStopTimeout = 5 * time.Second @@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) { logLevel = "info" daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir) + // Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run. + t.Cleanup(func() { + cfg, err := newSVCConfig() + if err != nil { + t.Errorf("cleanup: create service config: %v", err) + return + } + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + if err != nil { + t.Errorf("cleanup: create service: %v", err) + return + } + + // If the subtests already cleaned up, there's nothing to do. + if _, err := s.Status(); err != nil { + return + } + + if err := s.Stop(); err != nil { + t.Errorf("cleanup: stop service: %v", err) + } + if err := s.Uninstall(); err != nil { + t.Errorf("cleanup: uninstall service: %v", err) + } + }) + ctx := context.Background() t.Run("Install", func(t *testing.T) { diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go index 5e656650b..ee12326db 100644 --- a/client/cmd/signer/artifactkey.go +++ b/client/cmd/signer/artifactkey.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) var ( diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go index 881be9367..7c02323dc 100644 --- a/client/cmd/signer/artifactsign.go +++ b/client/cmd/signer/artifactsign.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) const ( diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go index 1d84b65c3..5ff636dcb 100644 --- a/client/cmd/signer/revocation.go +++ b/client/cmd/signer/revocation.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) const ( diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go index 78ac36b41..eae0da84d 100644 --- a/client/cmd/signer/rootkey.go +++ b/client/cmd/signer/rootkey.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) var ( diff --git a/client/cmd/status.go b/client/cmd/status.go index f09c35c2c..c35a06eb3 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -28,6 +28,7 @@ var ( ipsFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{} connectionTypeFilter string + checkFlag string ) var statusCmd = &cobra.Command{ @@ -49,6 +50,7 @@ func init() { statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") + statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") } func statusFunc(cmd *cobra.Command, args []string) error { @@ -56,6 +58,10 @@ func statusFunc(cmd *cobra.Command, args []string) error { cmd.SetOut(cmd.OutOrStdout()) + if checkFlag != "" { + return runHealthCheck(cmd) + } + err := parseFilters() if err != nil { return err @@ -68,15 +74,17 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx, false) + resp, err := getStatus(ctx, true, false) if err != nil { return err } status := resp.GetStatus() - if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) || - status == string(internal.StatusSessionExpired) { + needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) || + status == string(internal.StatusSessionExpired) + + if needsAuth && !jsonFlag && !yamlFlag { cmd.Printf("Daemon status: %s\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+ " netbird up \n\n"+ @@ -99,7 +107,17 @@ func statusFunc(cmd *cobra.Command, args []string) error { profName = activeProf.Name } - var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) + var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{ + Anonymize: anonymizeFlag, + DaemonVersion: resp.GetDaemonVersion(), + DaemonStatus: nbstatus.ParseDaemonStatus(status), + StatusFilter: statusFilter, + PrefixNamesFilter: prefixNamesFilter, + PrefixNamesFilterMap: prefixNamesFilterMap, + IPsFilter: ipsFilterMap, + ConnectionTypeFilter: connectionTypeFilter, + ProfileName: profName, + }) var statusOutputString string switch { case detailFlag: @@ -121,7 +139,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { //nolint @@ -131,7 +149,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } @@ -185,6 +203,83 @@ func enableDetailFlagWhenFilterFlag() { } } +func runHealthCheck(cmd *cobra.Command) error { + check := strings.ToLower(checkFlag) + switch check { + case "live", "ready", "startup": + default: + return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag) + } + + if err := util.InitLog(logLevel, util.LogConsole); err != nil { + return fmt.Errorf("init log: %w", err) + } + + ctx := internal.CtxInitState(cmd.Context()) + + isStartup := check == "startup" + resp, err := getStatus(ctx, isStartup, false) + if err != nil { + return err + } + + switch check { + case "live": + return nil + case "ready": + return checkReadiness(resp) + case "startup": + return checkStartup(resp) + default: + return nil + } +} + +func checkReadiness(resp *proto.StatusResponse) error { + daemonStatus := internal.StatusType(resp.GetStatus()) + switch daemonStatus { + case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected: + return nil + case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired: + return fmt.Errorf("readiness check: daemon status is %s", daemonStatus) + default: + return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus) + } +} + +func checkStartup(resp *proto.StatusResponse) error { + fullStatus := resp.GetFullStatus() + if fullStatus == nil { + return fmt.Errorf("startup check: no full status available") + } + + if !fullStatus.GetManagementState().GetConnected() { + return fmt.Errorf("startup check: management not connected") + } + + if !fullStatus.GetSignalState().GetConnected() { + return fmt.Errorf("startup check: signal not connected") + } + + var relayCount, relaysConnected int + for _, r := range fullStatus.GetRelays() { + uri := r.GetURI() + if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") { + continue + } + relayCount++ + if r.GetAvailable() { + relaysConnected++ + } + } + + if relayCount > 0 && relaysConnected == 0 { + return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount) + } + + return nil +} + func parseInterfaceIP(interfaceIP string) string { ip, _, err := net.ParseCIDR(interfaceIP) if err != nil { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 4bda33e65..d7564c353 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -13,6 +13,8 @@ import ( "github.com/netbirdio/management-integrations/integrations" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" @@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp jobManager := job.NewJobManager(nil, store, peersmanager) - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } @@ -152,7 +160,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false) + "", "", false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 9559287d5..f5766522a 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r, false) + connectClient := internal.NewConnectClient(ctx, config, r) SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil, util.FindFirstLogPath(logFiles)) diff --git a/client/cmd/update_supported.go b/client/cmd/update_supported.go index 977875093..0b197f4c5 100644 --- a/client/cmd/update_supported.go +++ b/client/cmd/update_supported.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/util" ) diff --git a/client/embed/embed.go b/client/embed/embed.go index 4fbe0eada..88f7e541c 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -14,6 +14,7 @@ import ( "github.com/sirupsen/logrus" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" @@ -21,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" sshcommon "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -31,14 +33,14 @@ var ( ErrConfigNotInitialized = errors.New("config not initialized") ) -// PeerConnStatus is a peer's connection status. -type PeerConnStatus = peer.ConnStatus - const ( // PeerStatusConnected indicates the peer is in connected state. PeerStatusConnected = peer.StatusConnected ) +// PeerConnStatus is a peer's connection status. +type PeerConnStatus = peer.ConnStatus + // Client manages a netbird embedded client instance. type Client struct { deviceName string @@ -81,6 +83,14 @@ type Options struct { BlockInbound bool // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. WireguardPort *int + // MTU is the MTU for the WireGuard interface. + // Valid values are in the range 576..8192 bytes. + // If non-nil, this value overrides any value stored in the config file. + // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. + // Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams. + MTU *uint16 + // DNSLabels defines additional DNS labels configured in the peer. + DNSLabels []string } // validateCredentials checks that exactly one credential type is provided @@ -112,6 +122,12 @@ func New(opts Options) (*Client, error) { return nil, err } + if opts.MTU != nil { + if err := iface.ValidateMTU(*opts.MTU); err != nil { + return nil, fmt.Errorf("invalid MTU: %w", err) + } + } + if opts.LogOutput != nil { logrus.SetOutput(opts.LogOutput) } @@ -140,9 +156,14 @@ func New(opts Options) (*Client, error) { } } + var err error + var parsedLabels domain.List + if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil { + return nil, fmt.Errorf("invalid dns labels: %w", err) + } + t := true var config *profilemanager.Config - var err error input := profilemanager.ConfigInput{ ConfigPath: opts.ConfigPath, ManagementURL: opts.ManagementURL, @@ -151,6 +172,8 @@ func New(opts Options) (*Client, error) { DisableClientRoutes: &opts.DisableClientRoutes, BlockInbound: &opts.BlockInbound, WireguardPort: opts.WireguardPort, + MTU: opts.MTU, + DNSLabels: parsedLabels, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) @@ -202,7 +225,7 @@ func (c *Client) Start(startCtx context.Context) error { if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } - client := internal.NewConnectClient(ctx, c.config, c.recorder, false) + client := internal.NewConnectClient(ctx, c.config, c.recorder) client.SetSyncResponsePersistence(true) // either startup error (permanent backoff err) or nil err (successful engine up) @@ -352,6 +375,32 @@ func (c *Client) NewHTTPClient() *http.Client { } } +// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL. +// It returns an ExposeSession. Call Wait on the session to keep it alive. +func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + mgr := engine.GetExposeManager() + if mgr == nil { + return nil, fmt.Errorf("expose manager not available") + } + + resp, err := mgr.Expose(ctx, req) + if err != nil { + return nil, fmt.Errorf("expose: %w", err) + } + + return &ExposeSession{ + Domain: resp.Domain, + ServiceName: resp.ServiceName, + ServiceURL: resp.ServiceURL, + mgr: mgr, + }, nil +} + // Status returns the current status of the client. func (c *Client) Status() (peer.FullStatus, error) { c.mu.Lock() diff --git a/client/embed/expose.go b/client/embed/expose.go new file mode 100644 index 000000000..825bb90ee --- /dev/null +++ b/client/embed/expose.go @@ -0,0 +1,45 @@ +package embed + +import ( + "context" + "errors" + + "github.com/netbirdio/netbird/client/internal/expose" +) + +const ( + // ExposeProtocolHTTP exposes the service as HTTP. + ExposeProtocolHTTP = expose.ProtocolHTTP + // ExposeProtocolHTTPS exposes the service as HTTPS. + ExposeProtocolHTTPS = expose.ProtocolHTTPS + // ExposeProtocolTCP exposes the service as TCP. + ExposeProtocolTCP = expose.ProtocolTCP + // ExposeProtocolUDP exposes the service as UDP. + ExposeProtocolUDP = expose.ProtocolUDP + // ExposeProtocolTLS exposes the service as TLS. + ExposeProtocolTLS = expose.ProtocolTLS +) + +// ExposeRequest is a request to expose a local service via the NetBird reverse proxy. +type ExposeRequest = expose.Request + +// ExposeProtocolType represents the protocol used for exposing a service. +type ExposeProtocolType = expose.ProtocolType + +// ExposeSession represents an active expose session. Use Wait to block until the session ends. +type ExposeSession struct { + Domain string + ServiceName string + ServiceURL string + + mgr *expose.Manager +} + +// Wait blocks while keeping the expose session alive. +// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session. +func (s *ExposeSession) Wait(ctx context.Context) error { + if s == nil || s.mgr == nil { + return errors.New("expose session is not initialized") + } + return s.mgr.KeepAlive(ctx, s.Domain) +} diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 12dcaee8a..d916ebad4 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "strconv" "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" @@ -35,20 +36,34 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" type FWType int func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { - // on the linux system we try to user nftables or iptables - // in any case, because we need to allow netbird interface traffic - // so we use AllowNetbird traffic from these firewall managers - // for the userspace packet filtering firewall + // We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall. + if iface.IsUserspaceBind() && forceUserspaceFirewall() { + log.Info("forcing userspace firewall") + return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu) + } + + // Use native firewall for either kernel or userspace, the interface appears identical to netfilter fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) + // Kernel cannot fall back to anything else, need to return error if !iface.IsUserspaceBind() { return fm, err } + // Fall back to the userspace packet filter if native is unavailable if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) + return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) + + // Native firewall handles packet filtering, but the userspace WireGuard bind + // needs a device filter for DNS interception hooks. Install a minimal + // hooks-only filter that passes all traffic through to the kernel firewall. + if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil { + log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err) + } + + return fm, nil } func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { @@ -160,3 +175,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool { _, err := client.ListChains("filter") return err == nil } + +func forceUserspaceFirewall() bool { + val := os.Getenv(EnvForceUserspaceFirewall) + if val == "" { + return false + } + + force, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err) + return false + } + return force +} diff --git a/client/firewall/iface.go b/client/firewall/iface.go index b83c5f912..491f03269 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -7,6 +7,12 @@ import ( "github.com/netbirdio/netbird/client/iface/wgaddr" ) +// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when +// native iptables/nftables is available. This only applies when the WireGuard interface +// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of +// kernel netfilter rules. +const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL" + // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f09..e629f7881 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -21,6 +21,10 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" + + // mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent + // external DNAT from bypassing ACL rules. + mangleFwdKey = "MANGLE-FORWARD" ) type aclEntries map[string][][]string @@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error { } } + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err) + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := m.flushIPSet(ipsetName); err != nil { if errors.Is(err, ipset.ErrSetNotExist) { @@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error { } for chainName, rules := range m.entries { + // mangle FORWARD guard rules are handled separately below + if chainName == mangleFwdKey { + continue + } for _, rule := range rules { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { log.Debugf("failed to create input chain jump rule: %s", err) @@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error { } clear(m.optionalEntries) + // Insert mangle FORWARD guard rules to prevent external DNAT bypass. + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to add mangle FORWARD guard rule: %v", err) + } + } + return nil } @@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN}) + + // Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it + // traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD + // can be inserted above ours. Mangle runs before filter, so these guard rules enforce the + // ACL mark check where it cannot be overridden. + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", + "-j", "ACCEPT", + }) + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "DNAT", + "-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), + "-j", "DROP", + }) } func (m *aclManager) seedInitialOptionalEntries() { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 716385705..a1d4467d5 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -23,16 +23,16 @@ type Manager struct { wgIface iFaceMapper - ipv4Client *iptables.IPTables - aclMgr *aclManager - router *router + ipv4Client *iptables.IPTables + aclMgr *aclManager + router *router + rawSupported bool } // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string Address() wgaddr.Address - IsUserspaceBind() bool } // Create iptables firewall manager @@ -63,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ - NameStr: m.wgIface.Name(), - WGAddress: m.wgIface.Address(), - UserspaceBind: m.wgIface.IsUserspaceBind(), - MTU: m.router.mtu, + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) @@ -84,7 +83,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.initNoTrackChain(); err != nil { - return fmt.Errorf("init notrack chain: %w", err) + log.Warnf("raw table not available, notrack rules will be disabled: %v", err) } // persist early to ensure cleanup of chains @@ -202,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nberrors.FormatErrorOrNil(merr) } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// This is called when USPFilter wraps the native firewall, adding blanket accept +// rules so that packet filtering is handled in userspace instead of by netfilter. func (m *Manager) AllowNetbird() error { - if !m.wgIface.IsUserspaceBind() { - return nil - } - _, err := m.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -285,6 +282,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + const ( chainNameRaw = "NETBIRD-RAW" chainOUTPUT = "OUTPUT" @@ -318,6 +331,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() + if !m.rawSupported { + return fmt.Errorf("raw table not available") + } + wgPortStr := fmt.Sprintf("%d", wgPort) proxyPortStr := fmt.Sprintf("%d", proxyPort) @@ -375,12 +392,16 @@ func (m *Manager) initNoTrackChain() error { return fmt.Errorf("add prerouting jump rule: %w", err) } + m.rawSupported = true return nil } func (m *Manager) cleanupNoTrackChain() error { exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw) if err != nil { + if !m.rawSupported { + return nil + } return fmt.Errorf("check chain exists: %w", err) } if !exists { @@ -401,6 +422,7 @@ func (m *Manager) cleanupNoTrackChain() error { return fmt.Errorf("clear and delete chain: %w", err) } + m.rawSupported = false return nil } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ee47a27c0..cc4bda0e0 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address { panic("AddressFunc is not set") } -func (i *iFaceMock) IsUserspaceBind() bool { return false } - func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1fe4c149f..a7c4f67dd 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -36,6 +36,7 @@ const ( chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainNATOutput = "NETBIRD-NAT-OUTPUT" chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" @@ -43,6 +44,7 @@ const ( jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpNatOutput = "jump-nat-output" jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" @@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error { } log.Debug("flushing routing related tables") + + // Remove jump rules from built-in chains before deleting custom chains, + // otherwise the chain deletion fails with "device or resource busy". + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { + log.Debugf("clean OUTPUT jump rule: %v", err) + } + for _, chainInfo := range []struct { chain string table string @@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainNATOutput, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) @@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto return nil } +// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use. +func (r *router) ensureNATOutputChain() error { + if _, exists := r.rules[jumpNatOutput]; exists { + return nil + } + + chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput) + if err != nil { + return fmt.Errorf("check chain %s: %w", chainNATOutput, err) + } + if !chainExists { + if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil { + return fmt.Errorf("create chain %s: %w", chainNATOutput, err) + } + } + + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil { + if !chainExists { + if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil { + log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr) + } + } + return fmt.Errorf("add OUTPUT jump rule: %w", err) + } + r.rules[jumpNatOutput] = jumpRule + + r.updateState() + return nil +} + +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + if err := r.ensureNATOutputChain(); err != nil { + return err + } + + dnatRule := []string{ + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil { + return fmt.Errorf("add output DNAT rule: %w", err) + } + r.rules[ruleID] = dnatRule + + r.updateState() + return nil +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil { + return fmt.Errorf("delete output DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index c88774c1f..121c755e9 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -9,10 +9,9 @@ import ( ) type InterfaceState struct { - NameStr string `json:"name"` - WGAddress wgaddr.Address `json:"wg_address"` - UserspaceBind bool `json:"userspace_bind"` - MTU uint16 `json:"mtu"` + NameStr string `json:"name"` + WGAddress wgaddr.Address `json:"wg_address"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address { return i.WGAddress } -func (i *InterfaceState) IsUserspaceBind() bool { - return i.UserspaceBind -} - type ShutdownState struct { sync.Mutex diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3511a5463..d65d717b3 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -169,6 +169,14 @@ type Manager interface { // RemoveInboundDNAT removes inbound DNAT rule RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. + // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. + AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. + // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. + RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // This prevents conntrack from interfering with WireGuard proxy communication. SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index acf482f86..0b5b61e04 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -40,7 +40,6 @@ func getTableName() string { type iFaceMapper interface { Name() string Address() wgaddr.Address - IsUserspaceBind() bool } // Manager of iptables firewall @@ -95,7 +94,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.initNoTrackChains(workTable); err != nil { - return fmt.Errorf("init notrack chains: %w", err) + log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } stateManager.RegisterState(&ShutdownState{}) @@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ - NameStr: m.wgIface.Name(), - WGAddress: m.wgIface.Address(), - UserspaceBind: m.wgIface.IsUserspaceBind(), - MTU: m.router.mtu, + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) @@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { return m.router.RemoveNatRule(pair) } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// This is called when USPFilter wraps the native firewall, adding blanket accept +// rules so that packet filtering is handled in userspace instead of by netfilter. func (m *Manager) AllowNetbird() error { - if !m.wgIface.IsUserspaceBind() { - return nil - } - m.mutex.Lock() defer m.mutex.Unlock() @@ -346,6 +342,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + const ( chainNameRawOutput = "netbird-raw-out" chainNameRawPrerouting = "netbird-raw-pre" diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 75b1e2b6c..d48e4ba88 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address { panic("AddressFunc is not set") } -func (i *iFaceMock) IsUserspaceBind() bool { return false } - func TestNftablesManager(t *testing.T) { // just check on the local interface diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index fde654c20..904daf7cb 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -36,6 +36,7 @@ const ( chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" + chainNameNATOutput = "netbird-nat-output" chainNameForward = "FORWARD" chainNameMangleForward = "netbird-mangle-forward" @@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto return nil } +// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use. +func (r *router) ensureNATOutputChain() error { + if _, exists := r.chains[chainNameNATOutput]; exists { + return nil + } + + r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameNATOutput, + Table: r.workTable, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityNATDest, + Type: nftables.ChainTypeNAT, + }) + + if err := r.conn.Flush(); err != nil { + delete(r.chains, chainNameNATOutput) + return fmt.Errorf("create NAT output chain: %w", err) + } + return nil +} + +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + if err := r.ensureNATOutputChain(); err != nil { + return err + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameNATOutput], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add output DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + rule, exists := r.rules[ruleID] + if !exists { + return nil + } + + if rule.Handle == 0 { + log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID) + delete(r.rules, ruleID) + return nil + } + + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete output DNAT rule: %w", err) + } + delete(r.rules, ruleID) + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index 48b7b3741..462ad2556 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -8,10 +8,9 @@ import ( ) type InterfaceState struct { - NameStr string `json:"name"` - WGAddress wgaddr.Address `json:"wg_address"` - UserspaceBind bool `json:"userspace_bind"` - MTU uint16 `json:"mtu"` + NameStr string `json:"name"` + WGAddress wgaddr.Address `json:"wg_address"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address { return i.WGAddress } -func (i *InterfaceState) IsUserspaceBind() bool { - return i.UserspaceBind -} - type ShutdownState struct { InterfaceState *InterfaceState `json:"interface_state,omitempty"` } diff --git a/client/firewall/uspfilter/common/hooks.go b/client/firewall/uspfilter/common/hooks.go new file mode 100644 index 000000000..dadd800dd --- /dev/null +++ b/client/firewall/uspfilter/common/hooks.go @@ -0,0 +1,37 @@ +package common + +import ( + "net/netip" + "sync/atomic" +) + +// PacketHook stores a registered hook for a specific IP:port. +type PacketHook struct { + IP netip.Addr + Port uint16 + Fn func([]byte) bool +} + +// HookMatches checks if a packet's destination matches the hook and invokes it. +func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool { + if h == nil { + return false + } + if h.IP == dstIP && h.Port == dport { + return h.Fn(packetData) + } + return false +} + +// SetHook atomically stores a hook, handling nil removal. +func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) { + if hook == nil { + ptr.Store(nil) + return + } + ptr.Store(&PacketHook{ + IP: ip, + Port: dPort, + Fn: hook, + }) +} diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index df2e274eb..24b3d0167 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -140,6 +140,10 @@ type Manager struct { mtu uint16 mssClampValue uint16 mssClampEnabled bool + + // Only one hook per protocol is supported. Outbound direction only. + udpHookOut atomic.Pointer[common.PacketHook] + tcpHookOut atomic.Pointer[common.PacketHook] } // decoder for packages @@ -594,6 +598,8 @@ func (m *Manager) resetState() { maps.Clear(m.incomingRules) maps.Clear(m.routeRulesMap) m.routeRules = m.routeRules[:0] + m.udpHookOut.Store(nil) + m.tcpHookOut.Store(nil) if m.udpTracker != nil { m.udpTracker.Close() @@ -713,6 +719,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } case layers.LayerTypeTCP: + if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) { + return true + } // Clamp MSS on all TCP SYN packets, including those from local IPs. // SNATed routed traffic may appear as local IP but still requires clamping. if m.mssClampEnabled { @@ -895,39 +904,12 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt d.dnatOrigPort = 0 } -// udpHooksDrop checks if any UDP hooks should drop the packet func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() + return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData) +} - // Check specific destination IP first - if rules, exists := m.outgoingRules[dstIP]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } - - // Check IPv4 unspecified address - if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } - - // Check IPv6 unspecified address - if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } - - return false +func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { + return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData) } // filterInbound implements filtering logic for incoming packets. @@ -1278,12 +1260,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } case layers.LayerTypeUDP: - // if rule has UDP hook (and if we are here we match this rule) - // we ignore rule.drop and call this hook - if rule.udpHook != nil { - return rule.mgmtId, rule.udpHook(packetData), true - } - if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { return rule.mgmtId, rule.drop, true } @@ -1342,65 +1318,14 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot return sourceMatched } -// AddUDPPacketHook calls hook when UDP packet from given direction matched -// -// Hook function returns flag which indicates should be the matched package dropped or not -func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string { - r := PeerRule{ - id: uuid.New().String(), - ip: ip, - protoLayer: layers.LayerTypeUDP, - dPort: &firewall.Port{Values: []uint16{dPort}}, - ipLayer: layers.LayerTypeIPv6, - udpHook: hook, - } - - if ip.Is4() { - r.ipLayer = layers.LayerTypeIPv4 - } - - m.mutex.Lock() - if in { - // Incoming UDP hooks are stored in allow rules map - if _, ok := m.incomingRules[r.ip]; !ok { - m.incomingRules[r.ip] = make(map[string]PeerRule) - } - m.incomingRules[r.ip][r.id] = r - } else { - if _, ok := m.outgoingRules[r.ip]; !ok { - m.outgoingRules[r.ip] = make(map[string]PeerRule) - } - m.outgoingRules[r.ip][r.id] = r - } - m.mutex.Unlock() - - return r.id +// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove. +func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { + common.SetHook(&m.udpHookOut, ip, dPort, hook) } -// RemovePacketHook removes packet hook by given ID -func (m *Manager) RemovePacketHook(hookID string) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Check incoming hooks (stored in allow rules) - for _, arr := range m.incomingRules { - for _, r := range arr { - if r.id == hookID { - delete(arr, r.id) - return nil - } - } - } - // Check outgoing hooks - for _, arr := range m.outgoingRules { - for _, r := range arr { - if r.id == hookID { - delete(arr, r.id) - return nil - } - } - } - return fmt.Errorf("hook with given id not found") +// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove. +func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { + common.SetHook(&m.tcpHookOut, ip, dPort, hook) } // SetLogLevel sets the log level for the firewall manager diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 55a8e723c..39e8efa2c 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" wgdevice "golang.zx2c4.com/wireguard/device" @@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) { } } -func TestAddUDPPacketHook(t *testing.T) { - tests := []struct { - name string - in bool - expDir fw.RuleDirection - ip netip.Addr - dPort uint16 - hook func([]byte) bool - expectedID string - }{ - { - name: "Test Outgoing UDP Packet Hook", - in: false, - expDir: fw.RuleDirectionOUT, - ip: netip.MustParseAddr("10.168.0.1"), - dPort: 8000, - hook: func([]byte) bool { return true }, - }, - { - name: "Test Incoming UDP Packet Hook", - in: true, - expDir: fw.RuleDirectionIN, - ip: netip.MustParseAddr("::1"), - dPort: 9000, - hook: func([]byte) bool { return false }, - }, - } +func TestSetUDPPacketHook(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger, nbiface.DefaultMTU) - require.NoError(t, err) + var called bool + manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool { + called = true + return true + }) - manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) + h := manager.udpHookOut.Load() + require.NotNil(t, h) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP) + assert.Equal(t, uint16(8000), h.Port) + assert.True(t, h.Fn(nil)) + assert.True(t, called) - var addedRule PeerRule - if tt.in { - // Incoming UDP hooks are stored in allow rules map - if len(manager.incomingRules[tt.ip]) != 1 { - t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip])) - return - } - for _, rule := range manager.incomingRules[tt.ip] { - addedRule = rule - } - } else { - if len(manager.outgoingRules[tt.ip]) != 1 { - t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip])) - return - } - for _, rule := range manager.outgoingRules[tt.ip] { - addedRule = rule - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil) + assert.Nil(t, manager.udpHookOut.Load()) +} - if tt.ip.Compare(addedRule.ip) != 0 { - t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) - return - } - if tt.dPort != addedRule.dPort.Values[0] { - t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0]) - return - } - if layers.LayerTypeUDP != addedRule.protoLayer { - t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) - return - } - if addedRule.udpHook == nil { - t.Errorf("expected udpHook to be set") - return - } - }) - } +func TestSetTCPPacketHook(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) + + var called bool + manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool { + called = true + return true + }) + + h := manager.tcpHookOut.Load() + require.NotNil(t, h) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP) + assert.Equal(t, uint16(53), h.Port) + assert.True(t, h.Fn(nil)) + assert.True(t, called) + + manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil) + assert.Nil(t, manager.tcpHookOut.Load()) } // TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added @@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) { require.NoError(t, manager.Close(nil)) }() - // Add a UDP packet hook - hookFunc := func(data []byte) bool { return true } - hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc) + manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true }) - // Assert the hook is added by finding it in the manager's outgoing rules - found := false - for _, arr := range manager.outgoingRules { - for _, rule := range arr { - if rule.id == hookID { - found = true - break - } - } - } + require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered") - if !found { - t.Fatalf("The hook was not added properly.") - } - - // Now remove the packet hook - err = manager.RemovePacketHook(hookID) - if err != nil { - t.Fatalf("Failed to remove hook: %s", err) - } - - // Assert the hook is removed by checking it in the manager's outgoing rules - for _, arr := range manager.outgoingRules { - for _, rule := range arr { - if rule.id == hookID { - t.Fatalf("The hook was not removed properly.") - } - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil) + assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed") } func TestProcessOutgoingHooks(t *testing.T) { @@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) { } hookCalled := false - hookID := manager.AddUDPPacketHook( - false, + manager.SetUDPPacketHook( netip.MustParseAddr("100.10.0.100"), 53, func([]byte) bool { @@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) { return true }, ) - require.NotEmpty(t, hookID) // Create test UDP packet ipv4 := &layers.IPv4{ diff --git a/client/firewall/uspfilter/hooks_filter.go b/client/firewall/uspfilter/hooks_filter.go new file mode 100644 index 000000000..8d3cc0f5c --- /dev/null +++ b/client/firewall/uspfilter/hooks_filter.go @@ -0,0 +1,90 @@ +package uspfilter + +import ( + "encoding/binary" + "net/netip" + "sync/atomic" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/iface/device" +) + +const ( + ipv4HeaderMinLen = 20 + ipv4ProtoOffset = 9 + ipv4FlagsOffset = 6 + ipv4DstOffset = 16 + ipProtoUDP = 17 + ipProtoTCP = 6 + ipv4FragOffMask = 0x1fff + // dstPortOffset is the offset of the destination port within a UDP or TCP header. + dstPortOffset = 2 +) + +// HooksFilter is a minimal packet filter that only handles outbound DNS hooks. +// It is installed on the WireGuard interface when the userspace bind is active +// but a full firewall filter (Manager) is not needed because a native kernel +// firewall (nftables/iptables) handles packet filtering. +type HooksFilter struct { + udpHook atomic.Pointer[common.PacketHook] + tcpHook atomic.Pointer[common.PacketHook] +} + +var _ device.PacketFilter = (*HooksFilter)(nil) + +// FilterOutbound checks outbound packets for DNS hook matches. +// Only IPv4 packets matching the registered hook IP:port are intercepted. +// IPv6 and non-IP packets pass through unconditionally. +func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool { + if len(packetData) < ipv4HeaderMinLen { + return false + } + + // Only process IPv4 packets, let everything else pass through. + if packetData[0]>>4 != 4 { + return false + } + + ihl := int(packetData[0]&0x0f) * 4 + if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 { + return false + } + + // Skip non-first fragments: they don't carry L4 headers. + flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2]) + if flagsAndOffset&ipv4FragOffMask != 0 { + return false + } + + dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4]) + if !ok { + return false + } + + proto := packetData[ipv4ProtoOffset] + dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2]) + + switch proto { + case ipProtoUDP: + return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData) + case ipProtoTCP: + return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData) + default: + return false + } +} + +// FilterInbound allows all inbound packets (native firewall handles filtering). +func (f *HooksFilter) FilterInbound([]byte, int) bool { + return false +} + +// SetUDPPacketHook registers the UDP packet hook. +func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) { + common.SetHook(&f.udpHook, ip, dPort, hook) +} + +// SetTCPPacketHook registers the TCP packet hook. +func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) { + common.SetHook(&f.tcpHook, ip, dPort, hook) +} diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index ffc807f46..f63fe3e45 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { if err != nil { log.Warnf("failed to get interfaces: %v", err) } else { + // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse + // case where an interface comes up between refreshes. for _, intf := range interfaces { m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) } diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 597f892cf..8ed32eb5e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +// TODO: also delegate to nativeFirewall when available for kernel WG mode func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) } +// AddOutputDNAT delegates to the native firewall if available. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if m.nativeFirewall == nil { + return fmt.Errorf("output DNAT not supported without native firewall") + } + return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT delegates to the native firewall if available. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if m.nativeFirewall == nil { + return nil + } + return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + // translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { if !m.portDNATEnabled.Load() { diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index dbe3a7858..08d68a78e 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -18,9 +18,7 @@ type PeerRule struct { protoLayer gopacket.LayerType sPort *firewall.Port dPort *firewall.Port - drop bool - - udpHook func([]byte) bool + drop bool } // ID returns the rule id diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index d9f9f1aa8..657f96fc0 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) { { name: "UDPTraffic_WithHook", setup: func(m *Manager) { - hookFunc := func([]byte) bool { - return true - } - m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc) + m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { + return true // drop (intercepted by hook) + }) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN) + return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT) }, expectedStages: []PacketStage{ StageReceived, - StageInboundPortDNAT, - StageInbound1to1NAT, - StageConntrack, - StageRouting, - StagePeerACL, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: false, diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 54966b50e..9a6bc0670 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff { // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) // for js, the outer websocket layer takes care of tls if tlsEnabled && runtime.GOOS != "js" { @@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - conn, err := grpc.DialContext( - connCtx, - addr, + opts := []grpc.DialOption{ transportOption, WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), @@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone Time: 30 * time.Second, Timeout: 10 * time.Second, }), - ) + } + opts = append(opts, extraOpts...) + + conn, err := grpc.DialContext(connCtx, addr, opts...) if err != nil { return nil, fmt.Errorf("dial context: %w", err) } diff --git a/client/iface/configurer/uapi.go b/client/iface/configurer/uapi.go index f85c7852a..d9bd9bfab 100644 --- a/client/iface/configurer/uapi.go +++ b/client/iface/configurer/uapi.go @@ -5,20 +5,18 @@ package configurer import ( "net" - log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/ipc" ) func openUAPI(deviceName string) (net.Listener, error) { uapiSock, err := ipc.UAPIOpen(deviceName) if err != nil { - log.Errorf("failed to open uapi socket: %v", err) return nil, err } listener, err := ipc.UAPIListen(deviceName, uapiSock) if err != nil { - log.Errorf("failed to listen on uapi socket: %v", err) + _ = uapiSock.Close() return nil, err } diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 1298c609d..e3a96590c 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder return wgCfg } +func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer { + return &WGUSPConfigurer{ + device: device, + deviceName: deviceName, + activityRecorder: activityRecorder, + } +} + func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 708f38d26..4357d1916 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -15,14 +15,17 @@ type PacketFilter interface { // FilterInbound filter incoming packets from external sources to host FilterInbound(packetData []byte, size int) bool - // AddUDPPacketHook calls hook when UDP packet from given direction matched - // - // Hook function returns flag which indicates should be the matched package dropped or not. - // Hook function receives raw network packet data as argument. - AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string + // SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port. + // Hook function returns true if the packet should be dropped. + // Only one UDP hook is supported; calling again replaces the previous hook. + // Pass nil hook to remove. + SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) - // RemovePacketHook removes hook by ID - RemovePacketHook(hookID string) error + // SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port. + // Hook function returns true if the packet should be dropped. + // Only one TCP hook is supported; calling again replaces the previous hook. + // Pass nil hook to remove. + SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) } // FilteredDevice to override Read or Write of packets diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e457657f7..1a92b148f 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) + t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { if cErr := tunIface.Close(); cErr != nil { diff --git a/client/iface/iface.go b/client/iface/iface.go index 9b331d68c..655dd1682 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error // Close closes the tunnel interface func (w *WGIface) Close() error { w.mu.Lock() - defer w.mu.Unlock() var result *multierror.Error @@ -225,7 +224,15 @@ func (w *WGIface) Close() error { result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) } - if err := w.tun.Close(); err != nil { + // Release w.mu before calling w.tun.Close(): the underlying + // wireguard-go device.Close() waits for its send/receive goroutines + // to drain. Some of those goroutines re-enter WGIface methods that + // take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so + // holding the mutex here would deadlock the shutdown path. + tun := w.tun + w.mu.Unlock() + + if err := tun.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) } diff --git a/client/iface/iface_close_test.go b/client/iface/iface_close_test.go new file mode 100644 index 000000000..171e15d0a --- /dev/null +++ b/client/iface/iface_close_test.go @@ -0,0 +1,113 @@ +//go:build !android + +package iface + +import ( + "errors" + "sync" + "testing" + "time" + + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// fakeTunDevice implements WGTunDevice and lets the test control when +// Close() returns. It mimics the wireguard-go shutdown path, which blocks +// until its goroutines drain. Some of those goroutines (e.g. the packet +// filter DNS hook in client/internal/dns) call back into WGIface, so if +// WGIface.Close() held w.mu across tun.Close() the shutdown would +// deadlock. +type fakeTunDevice struct { + closeStarted chan struct{} + unblockClose chan struct{} +} + +func (f *fakeTunDevice) Create() (device.WGConfigurer, error) { + return nil, errors.New("not implemented") +} +func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, errors.New("not implemented") +} +func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil } +func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} } +func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU } +func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" } +func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil } +func (f *fakeTunDevice) Device() *wgdevice.Device { return nil } +func (f *fakeTunDevice) GetNet() *netstack.Net { return nil } +func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil } + +func (f *fakeTunDevice) Close() error { + close(f.closeStarted) + <-f.unblockClose + return nil +} + +type fakeProxyFactory struct{} + +func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil } +func (fakeProxyFactory) GetProxyPort() uint16 { return 0 } +func (fakeProxyFactory) Free() error { return nil } + +// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock +// that surfaces as a macOS test-timeout in +// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu +// while waiting for the wireguard-go device goroutines to finish, and +// one of those goroutines (the DNS filter hook) calls back into +// WGIface.GetDevice() which needs the same mutex. The fix is to drop +// the lock before tun.Close() returns control. +func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) { + tun := &fakeTunDevice{ + closeStarted: make(chan struct{}), + unblockClose: make(chan struct{}), + } + w := &WGIface{ + tun: tun, + wgProxyFactory: fakeProxyFactory{}, + } + + closeDone := make(chan error, 1) + go func() { + closeDone <- w.Close() + }() + + select { + case <-tun.closeStarted: + case <-time.After(2 * time.Second): + close(tun.unblockClose) + t.Fatal("tun.Close() was never invoked") + } + + // Simulate the WireGuard read goroutine calling back into WGIface + // via the packet filter's DNS hook. If Close() still held w.mu + // during tun.Close(), this would block until the test timeout. + getDeviceDone := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = w.GetDevice() + close(getDeviceDone) + }() + + select { + case <-getDeviceDone: + case <-time.After(2 * time.Second): + close(tun.unblockClose) + wg.Wait() + t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun") + } + + close(tun.unblockClose) + select { + case <-closeDone: + case <-time.After(2 * time.Second): + t.Fatal("WGIface.Close() never returned after the tun was unblocked") + } +} diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index 566068aa5..5ae98039c 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { return m.recorder } -// AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string { +// SetUDPPacketHook mocks base method. +func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(string) - return ret0 + m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2) } -// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. -func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// SetUDPPacketHook indicates an expected call of SetUDPPacketHook. +func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2) +} + +// SetTCPPacketHook mocks base method. +func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2) +} + +// SetTCPPacketHook indicates an expected call of SetTCPPacketHook. +func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2) } // FilterInbound mocks base method. @@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1) } - -// RemovePacketHook mocks base method. -func (m *MockPacketFilter) RemovePacketHook(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemovePacketHook", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemovePacketHook indicates an expected call of RemovePacketHook. -func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) -} diff --git a/client/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go deleted file mode 100644 index 291ab9ab5..000000000 --- a/client/iface/mocks/iface/mocks/filter.go +++ /dev/null @@ -1,87 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockPacketFilter is a mock of PacketFilter interface. -type MockPacketFilter struct { - ctrl *gomock.Controller - recorder *MockPacketFilterMockRecorder -} - -// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter. -type MockPacketFilterMockRecorder struct { - mock *MockPacketFilter -} - -// NewMockPacketFilter creates a new mock instance. -func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter { - mock := &MockPacketFilter{ctrl: ctrl} - mock.recorder = &MockPacketFilterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { - return m.recorder -} - -// AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) -} - -// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. -func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) -} - -// FilterInbound mocks base method. -func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterInbound", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// FilterInbound indicates an expected call of FilterInbound. -func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0) -} - -// FilterOutbound mocks base method. -func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterOutbound", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// FilterOutbound indicates an expected call of FilterOutbound. -func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0) -} - -// SetNetwork mocks base method. -func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetNetwork", arg0) -} - -// SetNetwork indicates an expected call of SetNetwork. -func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0) -} diff --git a/client/iface/udpmux/universal.go b/client/iface/udpmux/universal.go index 43bfedaaa..89a7eefb9 100644 --- a/client/iface/udpmux/universal.go +++ b/client/iface/udpmux/universal.go @@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error { } if u.address.Network.Contains(a) { - log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) + log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) } @@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error { u.addrCache.Store(addr.String(), isRouted) if isRouted { // Extra log, as the error only shows up with ICE logging enabled - log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) + log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix) return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) } } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index bd7adfaef..408ed992f 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -19,6 +19,9 @@ import ( var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() func TestDefaultManager(t *testing.T) { + t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") + networkMap := &mgmProto.NetworkMap{ FirewallRules: []*mgmProto.FirewallRule{ { @@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) { func TestDefaultManagerStateless(t *testing.T) { // stateless currently only in userspace, so we have to disable kernel t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") t.Setenv("NB_DISABLE_CONNTRACK", "true") networkMap := &mgmProto.NetworkMap{ @@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) { // This tests the full ACL manager -> uspfilter integration. func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) { t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") networkMap := &mgmProto.NetworkMap{ FirewallRules: []*mgmProto.FirewallRule{ @@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) { // up when they're removed from the network map in a subsequent update. func TestDenyRulesCleanedUpOnRemoval(t *testing.T) { t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) { // one added without leaking. func TestRuleUpdateChangingAction(t *testing.T) { t.Setenv("NB_WG_KERNEL_DISABLED", "true") + t.Setenv(firewall.EnvForceUserspaceFirewall, "true") ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go index 44e98bede..bdfd07430 100644 --- a/client/internal/auth/auth.go +++ b/client/internal/auth/auth.go @@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) { var needsLogin bool err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { - _, _, err := a.doMgmLogin(client, ctx, pubSSHKey) + err := a.doMgmLogin(client, ctx, pubSSHKey) if isLoginNeeded(err) { needsLogin = true return nil @@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err var isAuthError bool err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { - serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey) - if serverKey != nil && isRegistrationNeeded(err) { + err := a.doMgmLogin(client, ctx, pubSSHKey) + if isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey) if err != nil { @@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err // getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) { - serverKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err - } - - protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey) + protoFlow, err := client.GetPKCEAuthorizationFlow() if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { log.Warnf("server couldn't find pkce flow, contact admin: %v", err) @@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro config := &PKCEAuthProviderConfig{ Audience: protoConfig.GetAudience(), ClientID: protoConfig.GetClientID(), - ClientSecret: protoConfig.GetClientSecret(), + ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck TokenEndpoint: protoConfig.GetTokenEndpoint(), AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(), Scope: protoConfig.GetScope(), @@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro // getDeviceFlow retrieves device authorization flow configuration and creates a flow instance func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) { - serverKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err - } - - protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey) + protoFlow, err := client.GetDeviceAuthorizationFlow() if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { log.Warnf("server couldn't find device flow, contact admin: %v", err) @@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, config := &DeviceAuthProviderConfig{ Audience: protoConfig.GetAudience(), ClientID: protoConfig.GetClientID(), - ClientSecret: protoConfig.GetClientSecret(), + ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck Domain: protoConfig.Domain, TokenEndpoint: protoConfig.GetTokenEndpoint(), DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(), @@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, } // doMgmLogin performs the actual login operation with the management service -func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) { - serverKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, nil, err - } - +func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error { sysInfo := system.GetInfo(ctx) a.setSystemInfoFlags(sysInfo) - loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels) - return serverKey, loginResp, err + _, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels) + return err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // Otherwise tries to register with the provided setupKey via command line. func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { - serverPublicKey, err := client.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err - } - validSetupKey, err := uuid.Parse(setupKey) if err != nil && jwtToken == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) @@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe log.Debugf("sending peer registration request to Management Service") info := system.GetInfo(ctx) a.setSystemInfoFlags(info) - loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) + loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) if err != nil { log.Errorf("failed registering peer %v", err) return nil, err diff --git a/client/internal/connect.go b/client/internal/connect.go index 17fc20c42..ac498f719 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -23,12 +23,13 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/updatemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater" + "github.com/netbirdio/netbird/client/internal/updater/installer" nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" @@ -43,14 +44,19 @@ import ( "github.com/netbirdio/netbird/version" ) -type ConnectClient struct { - ctx context.Context - config *profilemanager.Config - statusRecorder *peer.Status - doInitialAutoUpdate bool +// androidRunOverride is set on Android to inject mobile dependencies +// when using embed.Client (which calls Run() with empty MobileDependency). +var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error - engine *Engine - engineMutex sync.Mutex +type ConnectClient struct { + ctx context.Context + config *profilemanager.Config + statusRecorder *peer.Status + + engine *Engine + engineMutex sync.Mutex + clientMetrics *metrics.ClientMetrics + updateManager *updater.Manager persistSyncResponse bool } @@ -59,19 +65,24 @@ func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - doInitalAutoUpdate bool, ) *ConnectClient { return &ConnectClient{ - ctx: ctx, - config: config, - statusRecorder: statusRecorder, - doInitialAutoUpdate: doInitalAutoUpdate, - engineMutex: sync.Mutex{}, + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + engineMutex: sync.Mutex{}, } } +func (c *ConnectClient) SetUpdateManager(um *updater.Manager) { + c.updateManager = um +} + // Run with main logic. func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error { + if androidRunOverride != nil { + return androidRunOverride(c, runningChan, logPath) + } return c.run(MobileDependency{}, runningChan, logPath) } @@ -83,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid( dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, stateFilePath string, + cacheDir string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -92,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid( HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, StateFilePath: stateFilePath, + TempDir: cacheDir, } return c.run(mobileDependency, nil, "") } @@ -100,6 +113,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + dnsAddresses []netip.AddrPort, stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. @@ -109,6 +123,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + HostDNSAddresses: dnsAddresses, StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, "") @@ -131,10 +146,34 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } }() + // Stop metrics push on exit + defer func() { + if c.clientMetrics != nil { + c.clientMetrics.StopPush() + } + }() + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) nbnet.Init() + // Initialize metrics once at startup (always active for debug bundles) + if c.clientMetrics == nil { + agentInfo := metrics.AgentInfo{ + DeploymentType: metrics.DeploymentTypeUnknown, + Version: version.NetbirdVersion(), + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + c.clientMetrics = metrics.NewClientMetrics(agentInfo) + log.Debugf("initialized client metrics") + + // Start metrics push if enabled (uses daemon context, persists across engine restarts) + if metrics.IsMetricsPushEnabled() { + c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv()) + } + } + backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -187,14 +226,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan stateManager := statemanager.New(path) stateManager.RegisterState(&sshconfig.ShutdownState{}) - updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager) - if err == nil { - updateManager.CheckUpdateSuccess(c.ctx) + if c.updateManager != nil { + c.updateManager.CheckUpdateSuccess(c.ctx) + } - inst := installer.New() - if err := inst.CleanUpInstallerFiles(); err != nil { - log.Errorf("failed to clean up temporary installer file: %v", err) - } + inst := installer.New() + if err := inst.CleanUpInstallerFiles(); err != nil { + log.Errorf("failed to clean up temporary installer file: %v", err) } defer c.statusRecorder.ClientStop() @@ -222,6 +260,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder) mgmClient.SetConnStateListener(mgmNotifier) + // Update metrics with actual deployment type after connection + deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL()) + agentInfo := metrics.AgentInfo{ + DeploymentType: deploymentType, + Version: version.NetbirdVersion(), + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String()) + log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) defer func() { if err = mgmClient.Close(); err != nil { @@ -230,8 +278,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan }() // connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config + loginStarted := time.Now() loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config) if err != nil { + c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false) log.Debug(err) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { state.Set(StatusNeedsLogin) @@ -240,6 +290,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } return wrapErr(err) } + c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true) c.statusRecorder.MarkManagementConnected() localPeerState := peer.LocalPeerState{ @@ -289,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan log.Error(err) return wrapErr(err) } + engineConfig.TempDir = mobileDependency.TempDir relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) c.statusRecorder.SetRelayMgr(relayManager) @@ -308,7 +360,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager) + engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{ + SignalClient: signalClient, + MgmClient: mgmClient, + RelayManager: relayManager, + StatusRecorder: c.statusRecorder, + Checks: checks, + StateManager: stateManager, + UpdateManager: c.updateManager, + ClientMetrics: c.clientMetrics, + }, mobileDependency) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine c.engineMutex.Unlock() @@ -318,21 +379,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } - if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil { - // AutoUpdate will be true when the user click on "Connect" menu on the UI - if c.doInitialAutoUpdate { - log.Infof("start engine by ui, run auto-update check") - c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate) - c.doInitialAutoUpdate = false - } - } - log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) if runningChan != nil { - close(runningChan) - runningChan = nil + select { + case <-runningChan: + default: + close(runningChan) + } } <-engineCtx.Done() @@ -567,12 +622,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { - - serverPublicKey, err := client.GetServerPublicKey() - if err != nil { - return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err) - } - sysInfo := system.GetInfo(ctx) sysInfo.SetFlags( config.RosenpassEnabled, @@ -591,12 +640,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.EnableSSHRemotePortForwarding, config.DisableSSHAuth, ) - loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) - if err != nil { - return nil, err - } - - return loginResp, nil + return client.Login(sysInfo, pubSSHKey, config.DNSLabels) } func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier { diff --git a/client/internal/connect_android_default.go b/client/internal/connect_android_default.go new file mode 100644 index 000000000..190341c4a --- /dev/null +++ b/client/internal/connect_android_default.go @@ -0,0 +1,73 @@ +//go:build android + +package internal + +import ( + "net/netip" + + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android. +// It returns an empty interface list, which means ICE P2P candidates won't be +// discovered — connections will fall back to relay. Applications that need P2P +// should provide a real implementation via runOnAndroidEmbed that uses +// Android's ConnectivityManager to enumerate network interfaces. +type noopIFaceDiscover struct{} + +func (noopIFaceDiscover) IFaces() (string, error) { + // Return empty JSON array — no local interfaces advertised for ICE. + // This is intentional: without Android's ConnectivityManager, we cannot + // reliably enumerate interfaces (netlink is restricted on Android 11+). + // Relay connections still work; only P2P hole-punching is disabled. + return "[]", nil +} + +// noopNetworkChangeListener is a stub for embed.Client on Android. +// Network change events are ignored since the embed client manages its own +// reconnection logic via the engine's built-in retry mechanism. +type noopNetworkChangeListener struct{} + +func (noopNetworkChangeListener) OnNetworkChanged(string) { + // No-op: embed.Client relies on the engine's internal reconnection + // logic rather than OS-level network change notifications. +} + +func (noopNetworkChangeListener) SetInterfaceIP(string) { + // No-op: in netstack mode, the overlay IP is managed by the userspace + // network stack, not by OS-level interface configuration. +} + +// noopDnsReadyListener is a stub for embed.Client on Android. +// DNS readiness notifications are not needed in netstack/embed mode +// since system DNS is disabled and DNS resolution happens externally. +type noopDnsReadyListener struct{} + +func (noopDnsReadyListener) OnReady() { + // No-op: embed.Client does not need DNS readiness notifications. + // System DNS is disabled in netstack mode. +} + +var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{} +var _ listener.NetworkChangeListener = noopNetworkChangeListener{} +var _ dns.ReadyListener = noopDnsReadyListener{} + +func init() { + // Wire up the default override so embed.Client.Start() works on Android + // with netstack mode. Provides complete no-op stubs for all mobile + // dependencies so the engine's existing Android code paths work unchanged. + // Applications that need P2P ICE or real DNS should replace this by + // setting androidRunOverride before calling Start(). + androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error { + return c.runOnAndroidEmbed( + noopIFaceDiscover{}, + noopNetworkChangeListener{}, + []netip.AddrPort{}, + noopDnsReadyListener{}, + runningChan, + logPath, + ) + } +} diff --git a/client/internal/connect_android_embed.go b/client/internal/connect_android_embed.go new file mode 100644 index 000000000..18f72e841 --- /dev/null +++ b/client/internal/connect_android_embed.go @@ -0,0 +1,32 @@ +//go:build android + +package internal + +import ( + "net/netip" + + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan +// so embed.Client.Start() can detect when the engine is ready. +// It provides complete MobileDependency so the engine's existing +// Android code paths work unchanged. +func (c *ConnectClient) runOnAndroidEmbed( + iFaceDiscover stdnet.ExternalIFaceDiscover, + networkChangeListener listener.NetworkChangeListener, + dnsAddresses []netip.AddrPort, + dnsReadyListener dns.ReadyListener, + runningChan chan struct{}, + logPath string, +) error { + mobileDependency := MobileDependency{ + IFaceDiscover: iFaceDiscover, + NetworkChangeListener: networkChangeListener, + HostDNSAddresses: dnsAddresses, + DnsReadyListener: dnsReadyListener, + } + return c.run(mobileDependency, runningChan, logPath) +} diff --git a/client/internal/daemonaddr/resolve.go b/client/internal/daemonaddr/resolve.go new file mode 100644 index 000000000..b445696ab --- /dev/null +++ b/client/internal/daemonaddr/resolve.go @@ -0,0 +1,60 @@ +//go:build !windows && !ios && !android + +package daemonaddr + +import ( + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" +) + +var scanDir = "/var/run/netbird" + +// setScanDir overrides the scan directory (used by tests). +func setScanDir(dir string) { + scanDir = dir +} + +// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not, +// scans /var/run/netbird/ for a single .sock file to use instead. This handles the +// mismatch between the netbird@.service template (which places the socket under +// /var/run/netbird/.sock) and the CLI default (/var/run/netbird.sock). +func ResolveUnixDaemonAddr(addr string) string { + if !strings.HasPrefix(addr, "unix://") { + return addr + } + + sockPath := strings.TrimPrefix(addr, "unix://") + if _, err := os.Stat(sockPath); err == nil { + return addr + } + + entries, err := os.ReadDir(scanDir) + if err != nil { + return addr + } + + var found []string + for _, e := range entries { + if e.IsDir() { + continue + } + if strings.HasSuffix(e.Name(), ".sock") { + found = append(found, filepath.Join(scanDir, e.Name())) + } + } + + switch len(found) { + case 1: + resolved := "unix://" + found[0] + log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved) + return resolved + case 0: + return addr + default: + log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir) + return addr + } +} diff --git a/client/internal/daemonaddr/resolve_stub.go b/client/internal/daemonaddr/resolve_stub.go new file mode 100644 index 000000000..080b7171a --- /dev/null +++ b/client/internal/daemonaddr/resolve_stub.go @@ -0,0 +1,8 @@ +//go:build windows || ios || android + +package daemonaddr + +// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets. +func ResolveUnixDaemonAddr(addr string) string { + return addr +} diff --git a/client/internal/daemonaddr/resolve_test.go b/client/internal/daemonaddr/resolve_test.go new file mode 100644 index 000000000..3df67708a --- /dev/null +++ b/client/internal/daemonaddr/resolve_test.go @@ -0,0 +1,121 @@ +//go:build !windows && !ios && !android + +package daemonaddr + +import ( + "os" + "path/filepath" + "testing" +) + +// createSockFile creates a regular file with a .sock extension. +// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is +// sufficient and avoids Unix socket path-length limits on macOS. +func createSockFile(t *testing.T, path string) { + t.Helper() + if err := os.WriteFile(path, nil, 0o600); err != nil { + t.Fatalf("failed to create test sock file at %s: %v", path, err) + } +} + +func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) { + tmp := t.TempDir() + sock := filepath.Join(tmp, "netbird.sock") + createSockFile(t, sock) + + addr := "unix://" + sock + got := ResolveUnixDaemonAddr(addr) + if got != addr { + t.Errorf("expected %s, got %s", addr, got) + } +} + +func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) { + tmp := t.TempDir() + + // Default socket does not exist + defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock") + + // Create a scan dir with one socket + sd := filepath.Join(tmp, "netbird") + if err := os.MkdirAll(sd, 0o755); err != nil { + t.Fatal(err) + } + instanceSock := filepath.Join(sd, "main.sock") + createSockFile(t, instanceSock) + + origScanDir := scanDir + setScanDir(sd) + t.Cleanup(func() { setScanDir(origScanDir) }) + + got := ResolveUnixDaemonAddr(defaultAddr) + expected := "unix://" + instanceSock + if got != expected { + t.Errorf("expected %s, got %s", expected, got) + } +} + +func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) { + tmp := t.TempDir() + + defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock") + + sd := filepath.Join(tmp, "netbird") + if err := os.MkdirAll(sd, 0o755); err != nil { + t.Fatal(err) + } + createSockFile(t, filepath.Join(sd, "main.sock")) + createSockFile(t, filepath.Join(sd, "other.sock")) + + origScanDir := scanDir + setScanDir(sd) + t.Cleanup(func() { setScanDir(origScanDir) }) + + got := ResolveUnixDaemonAddr(defaultAddr) + if got != defaultAddr { + t.Errorf("expected original %s, got %s", defaultAddr, got) + } +} + +func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) { + tmp := t.TempDir() + + defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock") + + sd := filepath.Join(tmp, "netbird") + if err := os.MkdirAll(sd, 0o755); err != nil { + t.Fatal(err) + } + + origScanDir := scanDir + setScanDir(sd) + t.Cleanup(func() { setScanDir(origScanDir) }) + + got := ResolveUnixDaemonAddr(defaultAddr) + if got != defaultAddr { + t.Errorf("expected original %s, got %s", defaultAddr, got) + } +} + +func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) { + addr := "tcp://127.0.0.1:41731" + got := ResolveUnixDaemonAddr(addr) + if got != addr { + t.Errorf("expected %s, got %s", addr, got) + } +} + +func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) { + tmp := t.TempDir() + + defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock") + + origScanDir := scanDir + setScanDir(filepath.Join(tmp, "nonexistent")) + t.Cleanup(func() { setScanDir(origScanDir) }) + + got := ResolveUnixDaemonAddr(defaultAddr) + if got != defaultAddr { + t.Errorf("expected original %s, got %s", defaultAddr, got) + } +} diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 0f8243e7a..bddb9a69e 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -16,7 +16,6 @@ import ( "path/filepath" "runtime" "runtime/pprof" - "slices" "sort" "strings" "time" @@ -25,13 +24,12 @@ import ( "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" - "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/version" ) const readmeContent = `Netbird debug bundle @@ -53,6 +51,8 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states for the active profile. +service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists. +metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -219,6 +219,11 @@ const ( darwinStdoutLogPath = "/var/log/netbird.err.log" ) +// MetricsExporter is an interface for exporting metrics +type MetricsExporter interface { + Export(w io.Writer) error +} + type BundleGenerator struct { anonymizer *anonymize.Anonymizer @@ -227,8 +232,10 @@ type BundleGenerator struct { statusRecorder *peer.Status syncResponse *mgmProto.SyncResponse logPath string + tempDir string cpuProfile []byte refreshStatus func() // Optional callback to refresh status before bundle generation + clientMetrics MetricsExporter anonymize bool includeSystemInfo bool @@ -248,8 +255,10 @@ type GeneratorDependencies struct { StatusRecorder *peer.Status SyncResponse *mgmProto.SyncResponse LogPath string + TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used. CPUProfile []byte RefreshStatus func() // Optional callback to refresh status before bundle generation + ClientMetrics MetricsExporter } func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { @@ -266,8 +275,10 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen statusRecorder: deps.StatusRecorder, syncResponse: deps.SyncResponse, logPath: deps.LogPath, + tempDir: deps.TempDir, cpuProfile: deps.CPUProfile, refreshStatus: deps.RefreshStatus, + clientMetrics: deps.ClientMetrics, anonymize: cfg.Anonymize, includeSystemInfo: cfg.IncludeSystemInfo, @@ -277,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen // Generate creates a debug bundle and returns the location. func (g *BundleGenerator) Generate() (resp string, err error) { - bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") + bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip") if err != nil { return "", fmt.Errorf("create zip file: %w", err) } @@ -351,19 +362,20 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add corrupted state files to debug bundle: %v", err) } + if err := g.addServiceParams(); err != nil { + log.Errorf("failed to add service params to debug bundle: %v", err) + } + + if err := g.addMetrics(); err != nil { + log.Errorf("failed to add metrics to debug bundle: %v", err) + } + if err := g.addWgShow(); err != nil { log.Errorf("failed to add wg show output: %v", err) } - if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { - if err := g.addLogfile(); err != nil { - log.Errorf("failed to add log file to debug bundle: %v", err) - if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("failed to add systemd logs as fallback: %v", err) - } - } - } else if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("failed to add systemd logs: %v", err) + if err := g.addPlatformLog(); err != nil { + log.Errorf("failed to add logs to debug bundle: %v", err) } if err := g.addUpdateLogs(); err != nil { @@ -418,7 +430,10 @@ func (g *BundleGenerator) addStatus() error { fullStatus := g.statusRecorder.GetFullStatus() protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus) protoFullStatus.Events = g.statusRecorder.GetEventHistory() - overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName) + overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{ + Anonymize: g.anonymize, + ProfileName: profName, + }) statusOutput := overview.FullDetailSummary() statusReader := strings.NewReader(statusOutput) @@ -473,6 +488,90 @@ func (g *BundleGenerator) addConfig() error { return nil } +const ( + serviceParamsFile = "service.json" + serviceParamsBundle = "service_params.json" + maskedValue = "***" + envVarPrefix = "NB_" + jsonKeyManagementURL = "management_url" + jsonKeyServiceEnv = "service_env_vars" +) + +var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"} + +// addServiceParams reads the service.json file and adds a sanitized version to the bundle. +// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized. +func (g *BundleGenerator) addServiceParams() error { + path := filepath.Join(configs.StateDir, serviceParamsFile) + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read service params: %w", err) + } + + var params map[string]any + if err := json.Unmarshal(data, ¶ms); err != nil { + return fmt.Errorf("parse service params: %w", err) + } + + if g.anonymize { + if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" { + params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL) + } + } + + g.sanitizeServiceEnvVars(params) + + sanitizedData, err := json.MarshalIndent(params, "", " ") + if err != nil { + return fmt.Errorf("marshal sanitized service params: %w", err) + } + + if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil { + return fmt.Errorf("add service params to zip: %w", err) + } + + return nil +} + +// sanitizeServiceEnvVars masks or anonymizes env var values in service params. +// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked. +// Other NB_ var values are passed through the anonymizer when anonymization is enabled. +func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) { + envVars, ok := params[jsonKeyServiceEnv].(map[string]any) + if !ok { + return + } + + sanitized := make(map[string]any, len(envVars)) + for k, v := range envVars { + val, _ := v.(string) + switch { + case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k): + sanitized[k] = maskedValue + case g.anonymize: + sanitized[k] = g.anonymizer.AnonymizeString(val) + default: + sanitized[k] = val + } + } + params[jsonKeyServiceEnv] = sanitized +} + +// isSensitiveEnvVar returns true for env var names that may contain secrets. +func isSensitiveEnvVar(key string) bool { + lower := strings.ToLower(key) + for _, s := range sensitiveEnvSubstrings { + if strings.Contains(lower, s) { + return true + } + } + return false +} + func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { configContent.WriteString("NetBird Client Configuration:\n\n") @@ -744,6 +843,30 @@ func (g *BundleGenerator) addCorruptedStateFiles() error { return nil } +func (g *BundleGenerator) addMetrics() error { + if g.clientMetrics == nil { + log.Debugf("skipping metrics in debug bundle: no metrics collector") + return nil + } + + var buf bytes.Buffer + if err := g.clientMetrics.Export(&buf); err != nil { + return fmt.Errorf("export metrics: %w", err) + } + + if buf.Len() == 0 { + log.Debugf("skipping metrics.txt in debug bundle: no metrics data") + return nil + } + + if err := g.addFileToZip(&buf, "metrics.txt"); err != nil { + return fmt.Errorf("add metrics file to zip: %w", err) + } + + log.Debugf("added metrics to debug bundle") + return nil +} + func (g *BundleGenerator) addLogfile() error { if g.logPath == "" { log.Debugf("skipping empty log file in debug bundle") diff --git a/client/internal/debug/debug_android.go b/client/internal/debug/debug_android.go new file mode 100644 index 000000000..a4e2b3e98 --- /dev/null +++ b/client/internal/debug/debug_android.go @@ -0,0 +1,41 @@ +//go:build android + +package debug + +import ( + "fmt" + "io" + "os/exec" + + log "github.com/sirupsen/logrus" +) + +func (g *BundleGenerator) addPlatformLog() error { + cmd := exec.Command("/system/bin/logcat", "-d") + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("logcat stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("start logcat: %w", err) + } + + var logReader io.Reader = stdout + if g.anonymize { + var pw *io.PipeWriter + logReader, pw = io.Pipe() + go anonymizeLog(stdout, pw, g.anonymizer) + } + + if err := g.addFileToZip(logReader, "logcat.txt"); err != nil { + return fmt.Errorf("add logcat to zip: %w", err) + } + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("wait logcat: %w", err) + } + + log.Debug("added logcat output to debug bundle") + return nil +} diff --git a/client/internal/debug/debug_nonandroid.go b/client/internal/debug/debug_nonandroid.go new file mode 100644 index 000000000..117238dec --- /dev/null +++ b/client/internal/debug/debug_nonandroid.go @@ -0,0 +1,25 @@ +//go:build !android + +package debug + +import ( + "slices" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +func (g *BundleGenerator) addPlatformLog() error { + if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { + if err := g.addLogfile(); err != nil { + log.Errorf("failed to add log file to debug bundle: %v", err) + if err := g.trySystemdLogFallback(); err != nil { + return err + } + } + } else if err := g.trySystemdLogFallback(); err != nil { + return err + } + return nil +} diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 59837c328..6b5bb911c 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -1,8 +1,12 @@ package debug import ( + "archive/zip" + "bytes" "encoding/json" "net" + "os" + "path/filepath" "strings" "testing" @@ -10,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/configs" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) { } } +func TestIsSensitiveEnvVar(t *testing.T) { + tests := []struct { + key string + sensitive bool + }{ + {"NB_SETUP_KEY", true}, + {"NB_API_TOKEN", true}, + {"NB_CLIENT_SECRET", true}, + {"NB_PASSWORD", true}, + {"NB_CREDENTIAL", true}, + {"NB_LOG_LEVEL", false}, + {"NB_MANAGEMENT_URL", false}, + {"NB_HOSTNAME", false}, + {"HOME", false}, + {"PATH", false}, + } + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key)) + }) + } +} + +func TestSanitizeServiceEnvVars(t *testing.T) { + tests := []struct { + name string + anonymize bool + input map[string]any + check func(t *testing.T, params map[string]any) + }{ + { + name: "no env vars key", + anonymize: false, + input: map[string]any{"management_url": "https://mgmt.example.com"}, + check: func(t *testing.T, params map[string]any) { + t.Helper() + assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched") + _, ok := params[jsonKeyServiceEnv] + assert.False(t, ok, "service_env_vars should not be added") + }, + }, + { + name: "non-NB vars are masked", + anonymize: false, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "HOME": "/root", + "PATH": "/usr/bin", + "NB_LOG_LEVEL": "debug", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked") + assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked") + assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through") + }, + }, + { + name: "sensitive NB vars are masked", + anonymize: false, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "NB_SETUP_KEY": "abc123", + "NB_API_TOKEN": "tok_xyz", + "NB_LOG_LEVEL": "info", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked") + assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked") + assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through") + }, + }, + { + name: "safe NB vars anonymized when anonymize is true", + anonymize: true, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "NB_MANAGEMENT_URL": "https://mgmt.example.com:443", + "NB_LOG_LEVEL": "debug", + "NB_SETUP_KEY": "secret", + "SOME_OTHER": "val", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + // Safe NB_ values should be anonymized (not the original, not masked) + mgmtVal := env["NB_MANAGEMENT_URL"].(string) + assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized") + assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked") + + logVal := env["NB_LOG_LEVEL"].(string) + assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked") + + // Sensitive and non-NB_ still masked + assert.Equal(t, maskedValue, env["NB_SETUP_KEY"]) + assert.Equal(t, maskedValue, env["SOME_OTHER"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + g := &BundleGenerator{ + anonymize: tt.anonymize, + anonymizer: anonymizer, + } + g.sanitizeServiceEnvVars(tt.input) + tt.check(t, tt.input) + }) + } +} + +func TestAddServiceParams(t *testing.T) { + t.Run("missing service.json returns nil", func(t *testing.T) { + g := &BundleGenerator{ + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + } + + origStateDir := configs.StateDir + configs.StateDir = t.TempDir() + t.Cleanup(func() { configs.StateDir = origStateDir }) + + err := g.addServiceParams() + assert.NoError(t, err) + }) + + t.Run("management_url anonymized when anonymize is true", func(t *testing.T) { + dir := t.TempDir() + origStateDir := configs.StateDir + configs.StateDir = dir + t.Cleanup(func() { configs.StateDir = origStateDir }) + + input := map[string]any{ + jsonKeyManagementURL: "https://api.example.com:443", + jsonKeyServiceEnv: map[string]any{ + "NB_LOG_LEVEL": "trace", + }, + } + data, err := json.Marshal(input) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600)) + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + g := &BundleGenerator{ + anonymize: true, + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + archive: zw, + } + + require.NoError(t, g.addServiceParams()) + require.NoError(t, zw.Close()) + + zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + require.NoError(t, err) + require.Len(t, zr.File, 1) + assert.Equal(t, serviceParamsBundle, zr.File[0].Name) + + rc, err := zr.File[0].Open() + require.NoError(t, err) + defer rc.Close() + + var result map[string]any + require.NoError(t, json.NewDecoder(rc).Decode(&result)) + + mgmt := result[jsonKeyManagementURL].(string) + assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized") + assert.NotEmpty(t, mgmt) + + env := result[jsonKeyServiceEnv].(map[string]any) + assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked") + }) + + t.Run("management_url preserved when anonymize is false", func(t *testing.T) { + dir := t.TempDir() + origStateDir := configs.StateDir + configs.StateDir = dir + t.Cleanup(func() { configs.StateDir = origStateDir }) + + input := map[string]any{ + jsonKeyManagementURL: "https://api.example.com:443", + } + data, err := json.Marshal(input) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600)) + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + g := &BundleGenerator{ + anonymize: false, + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + archive: zw, + } + + require.NoError(t, g.addServiceParams()) + require.NoError(t, zw.Close()) + + zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + require.NoError(t, err) + + rc, err := zr.File[0].Open() + require.NoError(t, err) + defer rc.Close() + + var result map[string]any + require.NoError(t, json.NewDecoder(rc).Decode(&result)) + + assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved") + }) +} + // Helper function to check if IP is in CGNAT range func isInCGNATRange(ip net.IP) bool { cgnat := net.IPNet{ diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 06a2056b1..6fbdedc59 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { return nil } w.response = m + if m.MsgHdr.Truncated { + w.SetMeta("truncated", "true") + } return w.ResponseWriter.WriteMsg(m) } @@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { startTime := time.Now() requestID := resutil.GenerateRequestID() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": requestID, "dns_id": fmt.Sprintf("%04x", r.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) question := r.Question[0] qname := strings.ToLower(question.Name) @@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q meta += " " + k + "=" + v } - logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s", + logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s", qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer), - meta, time.Since(startTime)) + cw.response.Len(), meta, time.Since(startTime)) } func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 9b7a7b52b..4a8cf8cec 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr } } - log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains) + log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains)) return ruleIndex, nil } diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b374bcc6a..a67a23945 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability() {} +func (d *Resolver) ProbeAvailability(context.Context) {} // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 73f70035f..2c6b7dbc3 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { }) } -// TestLocalResolver_Stop tests cleanup on Stop +// TestLocalResolver_Stop tests cleanup on GracefullyStop func TestLocalResolver_Stop(t *testing.T) { - t.Run("Stop clears all state", func(t *testing.T) { + t.Run("GracefullyStop clears all state", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) { assert.False(t, resolver.isInManagedZone("host.example.com.")) }) - t.Run("Stop is safe to call multiple times", func(t *testing.T) { + t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) { resolver.Stop() }) - t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { + t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) { resolver := NewResolver() lookupStarted := make(chan struct{}) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index d01be0c2c..314af51d9 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve } } - if serverDomains.Flow != "" { - domains = append(domains, serverDomains.Flow) - } + // Flow receiver domain is intentionally excluded from caching. + // Cloud providers may rotate the IP behind this domain; a stale cached record + // causes TLS certificate verification failures on reconnect. for _, stun := range serverDomains.Stuns { if stun != "" { diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 99d289871..9e8a746f3 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { } assert.Len(t, resolver.GetCachedDomains(), 3) - // Update with partial ServerDomains (only flow domain - new type, should preserve all existing) + // Update with partial ServerDomains (only flow domain - flow is intentionally excluded from + // caching to prevent TLS failures from stale records, so all existing domains are preserved) partialDomains := dnsconfig.ServerDomains{ Flow: "github.com", } @@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { t.Skipf("Skipping test due to DNS resolution failure: %v", err) } - assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") + assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided") finalDomains := resolver.GetCachedDomains() - assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") + assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved") domainStrings := make([]string, len(finalDomains)) for i, d := range finalDomains { @@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { assert.Contains(t, domainStrings, "example.org") assert.Contains(t, domainStrings, "google.com") assert.Contains(t, domainStrings, "cloudflare.com") - assert.Contains(t, domainStrings, "github.com") + assert.NotContains(t, domainStrings, "github.com") } diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index fe160e20a..548b1f54f 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -85,6 +85,16 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } +// SetRouteChecker mock implementation of SetRouteChecker from Server interface +func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { + // Mock implementation - no-op +} + +// SetFirewall mock implementation of SetFirewall from Server interface +func (m *MockServer) SetFirewall(Firewall) { + // Mock implementation - no-op +} + // BeginBatch mock implementation of BeginBatch from Server interface func (m *MockServer) BeginBatch() { // Mock implementation - no-op diff --git a/client/internal/dns/response_writer.go b/client/internal/dns/response_writer.go index edc65a5d9..287cf28b0 100644 --- a/client/internal/dns/response_writer.go +++ b/client/internal/dns/response_writer.go @@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) { // After a call to Hijack(), the DNS package will not do anything with the connection. func (r *responseWriter) Hijack() { } + +// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging. +func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr { + var srcIP net.IP + if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil { + srcIP = ipv4.(*layers.IPv4).SrcIP + } else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil { + srcIP = ipv6.(*layers.IPv6).SrcIP + } + + var srcPort int + if udp := packet.Layer(layers.LayerTypeUDP); udp != nil { + srcPort = int(udp.(*layers.UDP).SrcPort) + } + + if srcIP == nil { + return nil + } + return &net.UDPAddr{IP: srcIP, Port: srcPort} +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 179517bbd..f7865047b 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -57,6 +57,8 @@ type Server interface { ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error + SetRouteChecker(func(netip.Addr) bool) + SetFirewall(Firewall) } type nsGroupsByDomain struct { @@ -104,12 +106,17 @@ type DefaultServer struct { statusRecorder *peer.Status stateManager *statemanager.Manager + routeMatch func(netip.Addr) bool + + probeMu sync.Mutex + probeCancel context.CancelFunc + probeWg sync.WaitGroup } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability() + ProbeAvailability(context.Context) ID() types.HandlerID } @@ -145,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default if config.WgInterface.IsUserspaceBind() { dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(config.WgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort, nil) } server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) @@ -180,11 +187,16 @@ func NewDefaultServerIos( ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager, + hostsDnsList []netip.AddrPort, statusRecorder *peer.Status, disableSys bool, ) *DefaultServer { + log.Debugf("iOS host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager + ds.hostsDNSHolder.set(hostsDnsList) + ds.permanent = true + ds.addHostRootZone() return ds } @@ -225,6 +237,14 @@ func newDefaultServer( return defaultServer } +// SetRouteChecker sets the function used by upstream resolvers to determine +// whether an IP is routed through the tunnel. +func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) { + s.mux.Lock() + defer s.mux.Unlock() + s.routeMatch = f +} + // RegisterHandler registers a handler for the given domains with the given priority. // Any previously registered handler for the same domain and priority will be replaced. func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -360,9 +380,26 @@ func (s *DefaultServer) DnsIP() netip.Addr { return s.service.RuntimeIP() } +// SetFirewall sets the firewall used for DNS port DNAT rules. +// This must be called before Initialize when using the listener-based service, +// because the firewall is typically not available at construction time. +func (s *DefaultServer) SetFirewall(fw Firewall) { + if svc, ok := s.service.(*serviceViaListener); ok { + svc.listenerFlagLock.Lock() + svc.firewall = fw + svc.listenerFlagLock.Unlock() + } +} + // Stop stops the server func (s *DefaultServer) Stop() { + s.probeMu.Lock() + if s.probeCancel != nil { + s.probeCancel() + } s.ctxCancel() + s.probeMu.Unlock() + s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -375,8 +412,12 @@ func (s *DefaultServer) Stop() { maps.Clear(s.extraDomains) } -func (s *DefaultServer) disableDNS() error { - defer s.service.Stop() +func (s *DefaultServer) disableDNS() (retErr error) { + defer func() { + if err := s.service.Stop(); err != nil { + retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err)) + } + }() if s.isUsingNoopHostManager() { return nil @@ -479,7 +520,8 @@ func (s *DefaultServer) SearchDomains() []string { } // ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds +// and deactivates the group if no server responds. +// If a previous probe is still running, it will be cancelled before starting a new one. func (s *DefaultServer) ProbeAvailability() { if val := os.Getenv(envSkipDNSProbe); val != "" { skipProbe, err := strconv.ParseBool(val) @@ -492,15 +534,52 @@ func (s *DefaultServer) ProbeAvailability() { } } - var wg sync.WaitGroup - for _, mux := range s.dnsMuxMap { - wg.Add(1) - go func(mux handlerWithStop) { - defer wg.Done() - mux.ProbeAvailability() - }(mux.handler) + s.probeMu.Lock() + + // don't start probes on a stopped server + if s.ctx.Err() != nil { + s.probeMu.Unlock() + return } + + // cancel any running probe + if s.probeCancel != nil { + s.probeCancel() + s.probeCancel = nil + } + + // wait for the previous probe goroutines to finish while holding + // the mutex so no other caller can start a new probe concurrently + s.probeWg.Wait() + + // start a new probe + probeCtx, probeCancel := context.WithCancel(s.ctx) + s.probeCancel = probeCancel + + s.probeWg.Add(1) + defer s.probeWg.Done() + + // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. + s.mux.Lock() + handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) + for _, mux := range s.dnsMuxMap { + handlers = append(handlers, mux.handler) + } + s.mux.Unlock() + + var wg sync.WaitGroup + for _, handler := range handlers { + wg.Add(1) + go func(h handlerWithStop) { + defer wg.Done() + h.ProbeAvailability(probeCtx) + }(handler) + } + + s.probeMu.Unlock() + wg.Wait() + probeCancel() } func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { @@ -695,6 +774,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { log.Errorf("failed to create upstream resolver for original nameservers: %v", err) return } + handler.routeMatch = s.routeMatch for _, ns := range originalNameservers { if ns == config.ServerIP { @@ -804,6 +884,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai if err != nil { return nil, fmt.Errorf("create upstream resolver: %v", err) } + handler.routeMatch = s.routeMatch for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { @@ -988,6 +1069,7 @@ func (s *DefaultServer) addHostRootZone() { log.Errorf("unable to create a new upstream resolver, error: %v", err) return } + handler.routeMatch = s.routeMatch handler.upstreamServers = maps.Keys(hostDNSServers) handler.deactivate = func(error) {} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3606d48b9..f77f6e898 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes() - packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - packetfilter.EXPECT().RemovePacketHook(gomock.Any()) + packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() if err := wgIface.SetFilter(packetfilter); err != nil { t.Errorf("set packet filter: %v", err) @@ -1065,13 +1065,13 @@ type mockHandler struct { func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) Stop() {} -func (m *mockHandler) ProbeAvailability() {} +func (m *mockHandler) ProbeAvailability(context.Context) {} func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} func (m *mockService) Listen() error { return nil } -func (m *mockService) Stop() {} +func (m *mockService) Stop() error { return nil } func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") } func (m *mockService) RuntimePort() int { return 53 } func (m *mockService) RegisterMux(string, dns.Handler) {} diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 6a76c53e3..1c6ce7849 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -4,15 +4,25 @@ import ( "net/netip" "github.com/miekg/dns" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" ) const ( DefaultPort = 53 ) +// Firewall provides DNAT capabilities for DNS port redirection. +// This is used when the DNS server cannot bind port 53 directly +// and needs firewall rules to redirect traffic. +type Firewall interface { + AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error +} + type service interface { Listen() error - Stop() + Stop() error RegisterMux(domain string, handler dns.Handler) DeregisterMux(key string) RuntimePort() int diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 806559444..4e09f1b7f 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -6,12 +6,17 @@ import ( "net" "net/netip" "runtime" + "strconv" "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ) @@ -30,25 +35,33 @@ type serviceViaListener struct { dnsMux *dns.ServeMux customAddr *netip.AddrPort server *dns.Server + tcpServer *dns.Server listenIP netip.Addr listenPort uint16 listenerIsRunning bool listenerFlagLock sync.Mutex ebpfService ebpfMgr.Manager + firewall Firewall + tcpDNATConfigured bool } -func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener { +func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener { mux := dns.NewServeMux() s := &serviceViaListener{ wgInterface: wgIface, dnsMux: mux, customAddr: customAddr, + firewall: fw, server: &dns.Server{ Net: "udp", Handler: mux, UDPSize: 65535, }, + tcpServer: &dns.Server{ + Net: "tcp", + Handler: mux, + }, } return s @@ -69,43 +82,86 @@ func (s *serviceViaListener) Listen() error { return fmt.Errorf("eval listen address: %w", err) } s.listenIP = s.listenIP.Unmap() - s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) - log.Debugf("starting dns on %s", s.server.Addr) - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) + addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) + s.server.Addr = addr + s.tcpServer.Addr = addr - err := s.server.ListenAndServe() - if err != nil { - log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err) + log.Debugf("starting dns on %s (UDP + TCP)", addr) + s.listenerIsRunning = true + + go func() { + if err := s.server.ListenAndServe(); err != nil { + log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err) + } + + s.listenerFlagLock.Lock() + unexpected := s.listenerIsRunning + s.listenerIsRunning = false + s.listenerFlagLock.Unlock() + + if unexpected { + if err := s.tcpServer.Shutdown(); err != nil { + log.Debugf("failed to shutdown DNS TCP server: %v", err) + } } }() + go func() { + if err := s.tcpServer.ListenAndServe(); err != nil { + log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err) + } + }() + + // When eBPF redirects UDP port 53 to our listen port, TCP still needs + // a DNAT rule because eBPF only handles UDP. + if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort { + if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err) + } else { + s.tcpDNATConfigured = true + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort) + } + } + return nil } -func (s *serviceViaListener) Stop() { +func (s *serviceViaListener) Stop() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() if !s.listenerIsRunning { - return + return nil } + s.listenerIsRunning = false ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := s.server.ShutdownContext(ctx) - if err != nil { - log.Errorf("stopping dns server listener returned an error: %v", err) + var merr *multierror.Error + + if err := s.server.ShutdownContext(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err)) + } + + if err := s.tcpServer.ShutdownContext(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err)) + } + + if s.tcpDNATConfigured && s.firewall != nil { + if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + s.tcpDNATConfigured = false } if s.ebpfService != nil { - err = s.ebpfService.FreeDNSFwd() - if err != nil { - log.Errorf("stopping traffic forwarder returned an error: %v", err) + if err := s.ebpfService.FreeDNSFwd(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err)) } } + + return nberrors.FormatErrorOrNil(merr) } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { @@ -132,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } -func (s *serviceViaListener) setListenerStatus(running bool) { - s.listenerFlagLock.Lock() - defer s.listenerFlagLock.Unlock() - - s.listenerIsRunning = running -} // evalListenAddress figure out the listen address for the DNS server // first check the 53 port availability on WG interface or lo, if not success @@ -186,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { } func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool { - addrString := fmt.Sprintf("%s:%d", ip, port) - udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) - probeListener, err := net.ListenUDP("udp", udpAddr) + addrPort := netip.AddrPortFrom(ip, uint16(port)) + + udpAddr := net.UDPAddrFromAddrPort(addrPort) + udpLn, err := net.ListenUDP("udp", udpAddr) if err != nil { - log.Warnf("binding dns on %s is not available, error: %s", addrString, err) + log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err) return false } - - err = probeListener.Close() - if err != nil { - log.Errorf("got an error closing the probe listener, error: %s", err) + if err := udpLn.Close(); err != nil { + log.Debugf("close UDP probe listener: %s", err) } + + tcpAddr := net.TCPAddrFromAddrPort(addrPort) + tcpLn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err) + return false + } + if err := tcpLn.Close(); err != nil { + log.Debugf("close TCP probe listener: %s", err) + } + return true } diff --git a/client/internal/dns/service_listener_test.go b/client/internal/dns/service_listener_test.go new file mode 100644 index 000000000..90ef71d19 --- /dev/null +++ b/client/internal/dns/service_listener_test.go @@ -0,0 +1,86 @@ +package dns + +import ( + "fmt" + "net" + "net/netip" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServiceViaListener_TCPAndUDP(t *testing.T) { + handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("192.0.2.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + // Create a service using a custom address to avoid needing root + svc := newServiceViaListener(nil, nil, nil) + svc.dnsMux.Handle(".", handler) + + // Bind both transports up front to avoid TOCTOU races. + udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0)) + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Skip("cannot bind to 127.0.0.153, skipping") + } + port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port) + + tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port)) + tcpLn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + udpConn.Close() + t.Skip("cannot bind TCP on same port, skipping") + } + + addr := fmt.Sprintf("%s:%d", customIP, port) + svc.server.PacketConn = udpConn + svc.tcpServer.Listener = tcpLn + svc.listenIP = customIP + svc.listenPort = port + + go func() { + if err := svc.server.ActivateAndServe(); err != nil { + t.Logf("udp server: %v", err) + } + }() + go func() { + if err := svc.tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + svc.listenerIsRunning = true + + defer func() { + require.NoError(t, svc.Stop()) + }() + + q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + // Test UDP query + udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second} + udpResp, _, err := udpClient.Exchange(q, addr) + require.NoError(t, err, "UDP query should succeed") + require.NotNil(t, udpResp) + require.NotEmpty(t, udpResp.Answer) + assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP") + + // Test TCP query + tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second} + tcpResp, _, err := tcpClient.Exchange(q, addr) + require.NoError(t, err, "TCP query should succeed") + require.NotNil(t, tcpResp) + require.NotEmpty(t, tcpResp.Answer) + assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP") +} diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 6ef0ab526..e8c036076 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -1,6 +1,7 @@ package dns import ( + "errors" "fmt" "net/netip" "sync" @@ -10,6 +11,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -18,7 +20,8 @@ type ServiceViaMemory struct { dnsMux *dns.ServeMux runtimeIP netip.Addr runtimePort int - udpFilterHookID string + tcpDNS *tcpDNSServer + tcpHookSet bool listenerIsRunning bool listenerFlagLock sync.Mutex } @@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { if err != nil { log.Errorf("get last ip from network: %v", err) } - s := &ServiceViaMemory{ + + return &ServiceViaMemory{ wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: lastIP, runtimePort: DefaultPort, } - return s } func (s *ServiceViaMemory) Listen() error { @@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error { return nil } - var err error - s.udpFilterHookID, err = s.filterDNSTraffic() - if err != nil { - return fmt.Errorf("filter dns traffice: %w", err) + if err := s.filterDNSTraffic(); err != nil { + return fmt.Errorf("filter dns traffic: %w", err) } s.listenerIsRunning = true @@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error { return nil } -func (s *ServiceViaMemory) Stop() { +func (s *ServiceViaMemory) Stop() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() if !s.listenerIsRunning { - return + return nil } - if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil { - log.Errorf("unable to remove DNS packet hook: %s", err) + filter := s.wgInterface.GetFilter() + if filter != nil { + filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil) + if s.tcpHookSet { + filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil) + } + } + + if s.tcpDNS != nil { + s.tcpDNS.Stop() } s.listenerIsRunning = false + + return nil } func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { @@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr { return s.runtimeIP } -func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { +func (s *ServiceViaMemory) filterDNSTraffic() error { filter := s.wgInterface.GetFilter() if filter == nil { - return "", fmt.Errorf("can't set DNS filter, filter not initialized") + return errors.New("DNS filter not initialized") + } + + // Create TCP DNS server lazily here since the device may not exist at construction time. + if s.tcpDNS == nil { + if dev := s.wgInterface.GetDevice(); dev != nil { + // MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact. + s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU) + } } firstLayerDecoder := layers.LayerTypeIPv4 @@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { } hook := func(packetData []byte) bool { - // Decode the packet packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) - // Get the UDP layer udpLayer := packet.Layer(layers.LayerTypeUDP) - udp := udpLayer.(*layers.UDP) + if udpLayer == nil { + return true + } + udp, ok := udpLayer.(*layers.UDP) + if !ok { + return true + } msg := new(dns.Msg) if err := msg.Unpack(udp.Payload); err != nil { @@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return true } - writer := responseWriter{ - packet: packet, - device: s.wgInterface.GetDevice().Device, + dev := s.wgInterface.GetDevice() + if dev == nil { + return true } - go s.dnsMux.ServeDNS(&writer, msg) + + writer := &responseWriter{ + remote: remoteAddrFromPacket(packet), + packet: packet, + device: dev.Device, + } + go s.dnsMux.ServeDNS(writer, msg) return true } - return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil + filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook) + + if s.tcpDNS != nil { + tcpHook := func(packetData []byte) bool { + s.tcpDNS.InjectPacket(packetData) + return true + } + filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook) + s.tcpHookSet = true + } + + return nil } diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go new file mode 100644 index 000000000..88e72e767 --- /dev/null +++ b/client/internal/dns/tcpstack.go @@ -0,0 +1,444 @@ +package dns + +import ( + "errors" + "fmt" + "io" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + dnsTCPReceiveWindow = 8192 + dnsTCPMaxInFlight = 16 + dnsTCPIdleTimeout = 30 * time.Second + dnsTCPReadTimeout = 5 * time.Second +) + +// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack. +// It is started lazily when a truncated DNS response is detected and shuts down +// after a period of inactivity to conserve resources. +type tcpDNSServer struct { + mu sync.Mutex + s *stack.Stack + ep *dnsEndpoint + mux *dns.ServeMux + tunDev tun.Device + ip netip.Addr + port uint16 + mtu uint16 + + running bool + closed bool + timerID uint64 + timer *time.Timer +} + +func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer { + return &tcpDNSServer{ + mux: mux, + tunDev: tunDev, + ip: ip, + port: port, + mtu: mtu, + } +} + +// InjectPacket ensures the stack is running and delivers a raw IP packet into +// the gvisor stack for TCP processing. Combining both operations under a single +// lock prevents a race where the idle timer could stop the stack between +// start and delivery. +func (t *tcpDNSServer) InjectPacket(payload []byte) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.closed { + return + } + + if !t.running { + if err := t.startLocked(); err != nil { + log.Errorf("failed to start TCP DNS stack: %v", err) + return + } + t.running = true + log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload)) + } + t.resetTimerLocked() + + ep := t.ep + if ep == nil || ep.dispatcher == nil { + return + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + // DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef. + ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) +} + +// Stop tears down the gvisor stack and releases resources permanently. +// After Stop, InjectPacket becomes a no-op. +func (t *tcpDNSServer) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopLocked() + t.closed = true +} + +func (t *tcpDNSServer) startLocked() error { + // TODO: add ipv6.NewProtocol when IPv6 overlay support lands. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: false, + }) + + nicID := tcpip.NICID(1) + ep := &dnsEndpoint{ + tunDev: t.tunDev, + } + ep.mtu.Store(uint32(t.mtu)) + + if err := s.CreateNIC(nicID, ep); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("create NIC: %v", err) + } + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(t.ip.AsSlice()), + PrefixLen: 32, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("add protocol address: %s", err) + } + + if err := s.SetPromiscuousMode(nicID, true); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("set promiscuous mode: %s", err) + } + if err := s.SetSpoofing(nicID, true); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("set spoofing: %s", err) + } + + defaultSubnet, err := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), + ) + if err != nil { + s.Close() + s.Wait() + return fmt.Errorf("create default subnet: %w", err) + } + + s.SetRouteTable([]tcpip.Route{ + {Destination: defaultSubnet, NIC: nicID}, + }) + + tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) { + t.handleTCPDNS(r) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + t.s = s + t.ep = ep + return nil +} + +func (t *tcpDNSServer) stopLocked() { + if !t.running { + return + } + + if t.timer != nil { + t.timer.Stop() + t.timer = nil + } + + if t.s != nil { + t.s.Close() + t.s.Wait() + t.s = nil + } + t.ep = nil + t.running = false + + log.Debugf("TCP DNS stack stopped") +} + +func (t *tcpDNSServer) resetTimerLocked() { + if t.timer != nil { + t.timer.Stop() + } + t.timerID++ + id := t.timerID + t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() { + t.mu.Lock() + defer t.mu.Unlock() + + // Only stop if this timer is still the active one. + // A racing InjectPacket may have replaced it. + if t.timerID != id { + return + } + t.stopLocked() + }) +} + +func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) { + id := r.ID() + + wq := waiter.Queue{} + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + log.Debugf("TCP DNS: failed to create endpoint: %v", epErr) + r.Complete(true) + return + } + r.Complete(false) + + conn := gonet.NewTCPConn(&wq, ep) + defer func() { + if err := conn.Close(); err != nil { + log.Tracef("TCP DNS: close conn: %v", err) + } + }() + + // Reset idle timer on activity + t.mu.Lock() + t.resetTimerLocked() + t.mu.Unlock() + + localAddr := &net.TCPAddr{ + IP: id.LocalAddress.AsSlice(), + Port: int(id.LocalPort), + } + remoteAddr := &net.TCPAddr{ + IP: id.RemoteAddress.AsSlice(), + Port: int(id.RemotePort), + } + + for { + if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil { + log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err) + break + } + + msg, err := readTCPDNSMessage(conn) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err) + } + break + } + + writer := &tcpResponseWriter{ + conn: conn, + localAddr: localAddr, + remoteAddr: remoteAddr, + } + t.mux.ServeDNS(writer, msg) + } +} + +// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device. +type dnsEndpoint struct { + dispatcher stack.NetworkDispatcher + tunDev tun.Device + mtu atomic.Uint32 +} + +func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher } +func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil } +func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() } +func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone } +func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 } +func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" } +func (e *dnsEndpoint) Wait() { /* no async work */ } +func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } +func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ } +func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true } +func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ } +func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ } +func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) } +func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ } + +const tunPacketOffset = 40 + +func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + var written int + for _, pkt := range pkts.AsSlice() { + data := stack.PayloadSince(pkt.NetworkHeader()) + if data == nil { + continue + } + + raw := data.AsSlice() + buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw)) + buf = append(buf, raw...) + data.Release() + + if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil { + log.Tracef("TCP DNS endpoint: failed to write packet: %v", err) + continue + } + written++ + } + return written, nil +} + +// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections. +type tcpResponseWriter struct { + conn *gonet.TCPConn + localAddr net.Addr + remoteAddr net.Addr +} + +func (w *tcpResponseWriter) LocalAddr() net.Addr { + return w.localAddr +} + +func (w *tcpResponseWriter) RemoteAddr() net.Addr { + return w.remoteAddr +} + +func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error { + data, err := msg.Pack() + if err != nil { + return fmt.Errorf("pack: %w", err) + } + + // DNS TCP: 2-byte length prefix + message + buf := make([]byte, 2+len(data)) + buf[0] = byte(len(data) >> 8) + buf[1] = byte(len(data)) + copy(buf[2:], data) + + if _, err = w.conn.Write(buf); err != nil { + return err + } + return nil +} + +func (w *tcpResponseWriter) Write(data []byte) (int, error) { + buf := make([]byte, 2+len(data)) + buf[0] = byte(len(data) >> 8) + buf[1] = byte(len(data)) + copy(buf[2:], data) + if _, err := w.conn.Write(buf); err != nil { + return 0, err + } + return len(data), nil +} + +func (w *tcpResponseWriter) Close() error { + return w.conn.Close() +} + +func (w *tcpResponseWriter) TsigStatus() error { return nil } +func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ } +func (w *tcpResponseWriter) Hijack() { /* not supported */ } + +// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed). +func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) { + // DNS over TCP uses a 2-byte length prefix + lenBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return nil, fmt.Errorf("read length: %w", err) + } + + msgLen := int(lenBuf[0])<<8 | int(lenBuf[1]) + if msgLen == 0 || msgLen > 65535 { + return nil, fmt.Errorf("invalid message length: %d", msgLen) + } + + msgBuf := make([]byte, msgLen) + if _, err := io.ReadFull(conn, msgBuf); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + msg := new(dns.Msg) + if err := msg.Unpack(msgBuf); err != nil { + return nil, fmt.Errorf("unpack: %w", err) + } + return msg, nil +} + +// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging. +// Supports both IPv4 and IPv6. +func srcAddrFromPacket(pkt []byte) netip.AddrPort { + if len(pkt) == 0 { + return netip.AddrPort{} + } + + srcIP, transportOffset := srcIPFromPacket(pkt) + if !srcIP.IsValid() || len(pkt) < transportOffset+2 { + return netip.AddrPort{} + } + + srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1]) + return netip.AddrPortFrom(srcIP.Unmap(), srcPort) +} + +func srcIPFromPacket(pkt []byte) (netip.Addr, int) { + switch header.IPVersion(pkt) { + case 4: + return srcIPv4(pkt) + case 6: + return srcIPv6(pkt) + default: + return netip.Addr{}, 0 + } +} + +func srcIPv4(pkt []byte) (netip.Addr, int) { + if len(pkt) < header.IPv4MinimumSize { + return netip.Addr{}, 0 + } + hdr := header.IPv4(pkt) + src := hdr.SourceAddress() + ip, ok := netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.Addr{}, 0 + } + return ip, int(hdr.HeaderLength()) +} + +func srcIPv6(pkt []byte) (netip.Addr, int) { + if len(pkt) < header.IPv6MinimumSize { + return netip.Addr{}, 0 + } + hdr := header.IPv6(pkt) + src := hdr.SourceAddress() + ip, ok := netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.Addr{}, 0 + } + return ip, header.IPv6MinimumSize +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 0fbd32771..746b73ca7 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -41,10 +41,61 @@ const ( reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second + + // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP + // payload from the tunnel MTU. + ipUDPHeaderSize = 60 + 8 ) const testRecord = "com." +const ( + protoUDP = "udp" + protoTCP = "tcp" +) + +type dnsProtocolKey struct{} + +// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. +func contextWithDNSProtocol(ctx context.Context, network string) context.Context { + return context.WithValue(ctx, dnsProtocolKey{}, network) +} + +// dnsProtocolFromContext retrieves the inbound DNS protocol from context. +func dnsProtocolFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok { + return v + } + return "" +} + +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. +type upstreamProtocolResult struct { + protocol string +} + +// contextWithupstreamProtocolResult stores a mutable result holder in the context. +func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { + r := &upstreamProtocolResult{} + return context.WithValue(ctx, upstreamProtocolKey{}, r), r +} + +// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present. +func setUpstreamProtocol(ctx context.Context, protocol string) { + if ctx == nil { + return + } + if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil { + r.protocol = protocol + } +} + type upstreamClient interface { exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) } @@ -65,10 +116,12 @@ type upstreamResolverBase struct { mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration + wg sync.WaitGroup deactivate func(error) reactivate func() statusRecorder *peer.Status + routeMatch func(netip.Addr) bool } type upstreamFailure struct { @@ -115,6 +168,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() + + u.mutex.Lock() + u.wg.Wait() + u.mutex.Unlock() + } // ServeDNS handles a DNS request @@ -131,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - ok, failures := u.tryUpstreamServers(w, r, logger) + // Propagate inbound protocol so upstream exchange can use TCP directly + // when the request came in over TCP. + ctx := u.ctx + if addr := w.RemoteAddr(); addr != nil { + network := addr.Network() + ctx = contextWithDNSProtocol(ctx, network) + resutil.SetMeta(w, "protocol", network) + } + + ok, failures := u.tryUpstreamServers(ctx, w, r, logger) if len(failures) > 0 { u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger) } @@ -146,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } } -func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { +func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { timeout := u.upstreamTimeout if len(u.upstreamServers) > 1 { maxTotal := 5 * time.Second @@ -161,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M var failures []upstreamFailure for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil { + if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { failures = append(failures, *failure) } else { return true, failures @@ -171,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M } // queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. -func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { var rm *dns.Msg var t time.Duration var err error var startTime time.Time + var upstreamProto *upstreamProtocolResult func() { - ctx, cancel := context.WithTimeout(u.ctx, timeout) + ctx, cancel := context.WithTimeout(parentCtx, timeout) defer cancel() + ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) startTime = time.Now() rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) }() @@ -196,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) return nil } @@ -213,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { u.successCount.Add(1) resutil.SetMeta(w, "upstream", upstream.String()) + if upstreamProto != nil && upstreamProto.protocol != "" { + resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + } // Clear Zero bit from external responses to prevent upstream servers from // manipulating our internal fallthrough signaling mechanism @@ -260,16 +332,10 @@ func formatFailures(failures []upstreamFailure) string { // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability() { +func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { u.mutex.Lock() defer u.mutex.Unlock() - select { - case <-u.ctx.Done(): - return - default: - } - // avoid probe if upstreams could resolve at least one query if u.successCount.Load() > 0 { return @@ -279,31 +345,39 @@ func (u *upstreamResolverBase) ProbeAvailability() { var mu sync.Mutex var wg sync.WaitGroup - var errors *multierror.Error + var errs *multierror.Error for _, upstream := range u.upstreamServers { - upstream := upstream - wg.Add(1) - go func() { + go func(upstream netip.AddrPort) { defer wg.Done() - err := u.testNameserver(upstream, 500*time.Millisecond) + err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) if err != nil { - errors = multierror.Append(errors, err) + mu.Lock() + errs = multierror.Append(errs, err) + mu.Unlock() log.Warnf("probing upstream nameserver %s: %s", upstream, err) return } mu.Lock() - defer mu.Unlock() success = true - }() + mu.Unlock() + }(upstream) } wg.Wait() + select { + case <-ctx.Done(): + return + case <-u.ctx.Done(): + return + default: + } + // didn't find a working upstream server, let's disable and try later if !success { - u.disable(errors.ErrorOrNil()) + u.disable(errs.ErrorOrNil()) if u.statusRecorder == nil { return @@ -339,7 +413,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } for _, upstream := range u.upstreamServers { - if err := u.testNameserver(upstream, probeTimeout); err != nil { + if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { log.Tracef("upstream check for %s: %s", upstream, err) } else { // at least one upstream server is available, stop probing @@ -351,16 +425,22 @@ func (u *upstreamResolverBase) waitUntilResponse() { return fmt.Errorf("upstream check call error") } - err := backoff.Retry(operation, exponentialBackOff) + err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx)) if err != nil { - log.Warn(err) + if errors.Is(err, context.Canceled) { + log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString()) + } else { + log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err) + } return } log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) u.successCount.Add(1) u.reactivate() + u.mutex.Lock() u.disabled = false + u.mutex.Unlock() } // isTimeout returns true if the given error is a network timeout error. @@ -383,7 +463,11 @@ func (u *upstreamResolverBase) disable(err error) { u.successCount.Store(0) u.deactivate(err) u.disabled = true - go u.waitUntilResponse() + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.waitUntilResponse() + }() } func (u *upstreamResolverBase) upstreamServersString() string { @@ -394,23 +478,57 @@ func (u *upstreamResolverBase) upstreamServersString() string { return strings.Join(servers, ", ") } -func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(u.ctx, timeout) +func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { + mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) defer cancel() + if externalCtx != nil { + stop2 := context.AfterFunc(externalCtx, cancel) + defer stop2() + } + r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - _, _, err := u.upstreamClient.exchange(ctx, server.String(), r) + _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) return err } +// clientUDPMaxSize returns the maximum UDP response size the client accepts. +func clientUDPMaxSize(r *dns.Msg) int { + if opt := r.IsEdns0(); opt != nil { + return int(opt.UDPSize()) + } + return dns.MinMsgSize +} + // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. +// If the inbound request came over TCP (via context), it skips the UDP attempt. // If the passed context is nil, this will use Exchange instead of ExchangeContext. func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { - // MTU - ip + udp headers - // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling. - client.UDPSize = uint16(currentMTU - (60 + 8)) + // If the request came in over TCP, go straight to TCP upstream. + if dnsProtocolFromContext(ctx) == protoTCP { + tcpClient := *client + tcpClient.Net = protoTCP + rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream) + if err != nil { + return nil, t, fmt.Errorf("with tcp: %w", err) + } + setUpstreamProtocol(ctx, protoTCP) + return rm, t, nil + } + + clientMaxSize := clientUDPMaxSize(r) + + // Cap EDNS0 to our tunnel MTU so the upstream doesn't send a + // response larger than our read buffer. + // Note: the query could be sent out on an interface that is not ours, + // but higher MTU settings could break truncation handling. + maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize) + client.UDPSize = maxUDPPayload + if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload { + opt.SetUDPSize(maxUDPPayload) + } var ( rm *dns.Msg @@ -429,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u } if rm == nil || !rm.MsgHdr.Truncated { + setUpstreamProtocol(ctx, protoUDP) return rm, t, nil } - log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + // TODO: if the upstream's truncated UDP response already contains more + // data than the client's buffer, we could truncate locally and skip + // the TCP retry. - client.Net = "tcp" + tcpClient := *client + tcpClient.Net = protoTCP if ctx == nil { - rm, t, err = client.Exchange(r, upstream) + rm, t, err = tcpClient.Exchange(r, upstream) } else { - rm, t, err = client.ExchangeContext(ctx, r, upstream) + rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream) } if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } - // TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP + setUpstreamProtocol(ctx, protoTCP) + + if rm.Len() > clientMaxSize { + rm.Truncate(clientMaxSize) + } return rm, t, nil } @@ -455,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u // ExchangeWithNetstack performs a DNS exchange using netstack for dialing. // This is needed when netstack is enabled to reach peer IPs through the tunnel. func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { - reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp") + // If request came in over TCP, go straight to TCP upstream + if dnsProtocolFromContext(ctx) == protoTCP { + rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) + if err != nil { + return nil, err + } + setUpstreamProtocol(ctx, protoTCP) + return rm, nil + } + + clientMaxSize := clientUDPMaxSize(r) + + // Cap EDNS0 to our tunnel MTU so the upstream doesn't send a + // response larger than what we can read over UDP. + maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize) + if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload { + opt.SetUDPSize(maxUDPPayload) + } + + reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP) if err != nil { return nil, err } - // If response is truncated, retry with TCP if reply != nil && reply.MsgHdr.Truncated { - log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - return netstackExchange(ctx, nsNet, r, upstream, "tcp") + rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) + if err != nil { + return nil, err + } + + setUpstreamProtocol(ctx, protoTCP) + if rm.Len() > clientMaxSize { + rm.Truncate(clientMaxSize) + } + + return rm, nil } + setUpstreamProtocol(ctx, protoUDP) + return reply, nil } @@ -487,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst } } - dnsConn := &dns.Conn{Conn: conn} + dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)} if err := dnsConn.WriteMsg(r); err != nil { return nil, fmt.Errorf("write %s message: %w", network, err) diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index d7cff377b..ee1ca42fe 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin upstreamExchangeClient := &dns.Client{ Timeout: ClientTimeout, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream) } // exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN @@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri Timeout: timeout, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 4d053a5a1..02c11173b 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { - log.Debugf("using private client to query upstream: %s", upstream) + needsPrivate := u.lNet.Contains(upstreamIP) || + (u.routeMatch != nil && u.routeMatch(upstreamIP)) + if needsPrivate { + log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) if err != nil { - return nil, 0, fmt.Errorf("error while creating private client: %s", err) + return nil, 0, fmt.Errorf("create private client: %s", err) } } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 8b06e4475..1797fdad8 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ProbeAvailability() + resolver.ProbeAvailability(context.TODO()) if !failed { t.Errorf("expected that resolving was deactivated") @@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) { }) } } + +func TestDNSProtocolContext(t *testing.T) { + t.Run("roundtrip udp", func(t *testing.T) { + ctx := contextWithDNSProtocol(context.Background(), protoUDP) + assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx)) + }) + + t.Run("roundtrip tcp", func(t *testing.T) { + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx)) + }) + + t.Run("missing returns empty", func(t *testing.T) { + assert.Equal(t, "", dnsProtocolFromContext(context.Background())) + }) +} + +func TestExchangeWithFallback_TCPContext(t *testing.T) { + // Start a local DNS server that responds on TCP only + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpServer := &dns.Server{ + Addr: "127.0.0.1:0", + Net: "tcp", + Handler: tcpHandler, + } + + tcpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tcpServer.Listener = tcpLn + + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = tcpServer.Shutdown() + }() + + upstream := tcpLn.Addr().String() + + // With TCP context, should connect directly via TCP without trying UDP + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, upstream) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.1") +} + +func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { + // UDP handler returns a truncated response to trigger TCP retry. + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = true + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + // TCP handler returns the full answer. + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.3"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{ + PacketConn: udpPC, + Net: "udp", + Handler: udpHandler, + } + + tcpLn, err := net.Listen("tcp", addr) + require.NoError(t, err) + + tcpServer := &dns.Server{ + Listener: tcpLn, + Net: "tcp", + Handler: tcpHandler, + } + + go func() { + if err := udpServer.ActivateAndServe(); err != nil { + t.Logf("udp server: %v", err) + } + }() + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = udpServer.Shutdown() + _ = tcpServer.Shutdown() + }() + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err, "should fall back to TCP after truncated UDP response") + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer") + assert.Contains(t, rm.Answer[0].String(), "10.0.0.3") + assert.False(t, rm.Truncated, "TCP response should not be truncated") +} + +func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) { + // Start only a TCP server (no UDP). With TCP context it should succeed. + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.2"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + tcpServer := &dns.Server{ + Listener: tcpLn, + Net: "tcp", + Handler: tcpHandler, + } + + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = tcpServer.Shutdown() + }() + + upstream := tcpLn.Addr().String() + + // TCP context: should skip UDP entirely and go directly to TCP + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, upstream) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.2") + + // Without TCP context, trying to reach a TCP-only server via UDP should fail + ctx2 := context.Background() + client2 := &dns.Client{Timeout: 500 * time.Millisecond} + _, _, err = ExchangeWithFallback(ctx2, client2, r, upstream) + assert.Error(t, err, "should fail when no UDP server and no TCP context") +} + +func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { + // Verify that a client EDNS0 larger than our MTU-derived limit gets + // capped in the outgoing request so the upstream doesn't send a + // response larger than our read buffer. + var receivedUDPSize uint16 + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + if opt := r.IsEdns0(); opt != nil { + receivedUDPSize = opt.UDPSize() + } + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler} + go func() { _ = udpServer.ActivateAndServe() }() + t.Cleanup(func() { _ = udpServer.Shutdown() }) + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + r.SetEdns0(4096, false) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + + expectedMax := uint16(currentMTU - ipUDPHeaderSize) + assert.Equal(t, expectedMax, receivedUDPSize, + "upstream should see capped EDNS0, not the client's 4096") +} + +func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) { + // When the client advertises a large EDNS0 (4096) and the upstream + // truncates, the TCP response should NOT be truncated since the full + // answer fits within the client's original buffer. + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = true + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + // Add enough records to exceed MTU but fit within 4096 + for i := range 20 { + m.Answer = append(m.Answer, &dns.TXT{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60}, + Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)}, + }) + } + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler} + tcpLn, err := net.Listen("tcp", addr) + require.NoError(t, err) + tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler} + + go func() { _ = udpServer.ActivateAndServe() }() + go func() { _ = tcpServer.ActivateAndServe() }() + t.Cleanup(func() { + _ = udpServer.Shutdown() + _ = tcpServer.Shutdown() + }) + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + + // Client with large buffer: should get all records without truncation + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + r.SetEdns0(4096, false) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records") + assert.False(t, rm.Truncated, "response should not be truncated for large buffer client") + + // Client with small buffer: should get truncated response + r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + r2.SetEdns0(512, false) + + rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr) + require.NoError(t, err) + require.NotNil(t, rm2) + assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") + assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 5c7cb31fc..2e8ef84ab 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re return } - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime)) } // udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. @@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { startTime := time.Now() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": resutil.GenerateRequestID(), "dns_id": fmt.Sprintf("%04x", query.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { startTime := time.Now() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": resutil.GenerateRequestID(), "dns_id": fmt.Sprintf("%04x", query.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) f.handleDNSQuery(logger, w, query, startTime) } diff --git a/client/internal/engine.go b/client/internal/engine.go index f2d724aa4..b49e02c6d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -36,7 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" + "github.com/netbirdio/netbird/client/internal/expose" "github.com/netbirdio/netbird/client/internal/ingressgw" + "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/netflow" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/networkmonitor" @@ -44,22 +46,21 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager" + "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/jobexec" cProto "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/shared/management/domain" - semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" - "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" mgm "github.com/netbirdio/netbird/shared/management/client" + "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/shared/relay/client" @@ -75,13 +76,11 @@ import ( const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms - connInitLimit = 200 disableAutoUpdate = "disabled" ) var ErrResetConnection = fmt.Errorf("reset connection") -// EngineConfig is a config for the Engine type EngineConfig struct { WgPort int WgIfaceName string @@ -141,6 +140,19 @@ type EngineConfig struct { ProfileConfig *profilemanager.Config LogPath string + TempDir string +} + +// EngineServices holds the external service dependencies required by the Engine. +type EngineServices struct { + SignalClient signal.Client + MgmClient mgm.Client + RelayManager *relayClient.Manager + StatusRecorder *peer.Status + Checks []*mgmProto.Checks + StateManager *statemanager.Manager + UpdateManager *updater.Manager + ClientMetrics *metrics.ClientMetrics } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -200,19 +212,19 @@ type Engine struct { // checks are the client-applied posture checks that need to be evaluated on the client checks []*mgmProto.Checks - relayManager *relayClient.Manager - stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + relayManager *relayClient.Manager + stateManager *statemanager.Manager + portForwardManager *portforward.Manager + srWatcher *guard.SRWatcher // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex persistSyncResponse bool latestSyncResponse *mgmProto.SyncResponse - connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager // auto-update - updateManager *updatemanager.Manager + updateManager *updater.Manager // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor @@ -222,8 +234,13 @@ type Engine struct { probeStunTurn *relay.StunTurnProbe + // clientMetrics collects and pushes metrics + clientMetrics *metrics.ClientMetrics + jobExecutor *jobexec.Executor jobExecutorWG sync.WaitGroup + + exposeManager *expose.Manager } // Peer is an instance of the Connection Peer @@ -240,35 +257,32 @@ type localIpUpdater interface { func NewEngine( clientCtx context.Context, clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, config *EngineConfig, + services EngineServices, mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, - stateManager *statemanager.Manager, ) *Engine { engine := &Engine{ - clientCtx: clientCtx, - clientCancel: clientCancel, - signal: signalClient, - signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), - mgmClient: mgmClient, - relayManager: relayManager, - peerStore: peerstore.NewConnStore(), - syncMsgMux: &sync.Mutex{}, - config: config, - mobileDep: mobileDep, - STUNs: []*stun.URI{}, - TURNs: []*stun.URI{}, - networkSerial: 0, - statusRecorder: statusRecorder, - stateManager: stateManager, - checks: checks, - connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), - jobExecutor: jobexec.NewExecutor(), + clientCtx: clientCtx, + clientCancel: clientCancel, + signal: services.SignalClient, + signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), + mgmClient: services.MgmClient, + relayManager: services.RelayManager, + peerStore: peerstore.NewConnStore(), + syncMsgMux: &sync.Mutex{}, + config: config, + mobileDep: mobileDep, + STUNs: []*stun.URI{}, + TURNs: []*stun.URI{}, + networkSerial: 0, + statusRecorder: services.StatusRecorder, + stateManager: services.StateManager, + portForwardManager: portforward.NewManager(), + checks: services.Checks, + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), + jobExecutor: jobexec.NewExecutor(), + clientMetrics: services.ClientMetrics, + updateManager: services.UpdateManager, } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -311,7 +325,7 @@ func (e *Engine) Stop() error { } if e.updateManager != nil { - e.updateManager.Stop() + e.updateManager.SetDownloadOnly() } log.Info("cleaning up status recorder states") @@ -419,6 +433,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.cancel() } e.ctx, e.cancel = context.WithCancel(e.clientCtx) + e.exposeManager = expose.NewManager(e.ctx, e.mgmClient) wgIface, err := e.newWgIface() if err != nil { @@ -488,6 +503,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) + e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { + for _, routes := range e.routeManager.GetSelectedClientRoutes() { + for _, r := range routes { + if r.Network.Contains(ip) { + return true + } + } + } + return false + }) + if err = e.wgInterfaceCreate(); err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) e.close() @@ -499,6 +525,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return err } + // Inject firewall into DNS server now that it's available. + // The DNS server is created before the firewall because the route manager + // depends on the DNS server, and the firewall depends on the wg interface. + e.dnsServer.SetFirewall(e.firewall) + e.udpMux, err = e.wgInterface.Up() if err != nil { log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) @@ -510,6 +541,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // conntrack entries from being created before the rules are in place e.setupWGProxyNoTrack() + // Start after interface is up since port may have been resolved from 0 or changed if occupied + e.shutdownWg.Add(1) + go func() { + defer e.shutdownWg.Done() + e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + }() + // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { e.rpManager.SetInterface(e.wgInterface) @@ -560,13 +598,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return nil } -func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() - - e.handleAutoUpdateVersion(autoUpdateSettings, true) -} - func (e *Engine) createFirewall() error { if e.config.DisableFirewall { log.Infof("firewall is disabled") @@ -794,45 +825,30 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg return nil } -func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) { +func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) { + if e.updateManager == nil { + return + } + if autoUpdateSettings == nil { return } - disabled := autoUpdateSettings.Version == disableAutoUpdate - - // Stop and cleanup if disabled - if e.updateManager != nil && disabled { - log.Infof("auto-update is disabled, stopping update manager") - e.updateManager.Stop() - e.updateManager = nil + if autoUpdateSettings.Version == disableAutoUpdate { + log.Infof("auto-update is disabled") + e.updateManager.SetDownloadOnly() return } - // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup - if !autoUpdateSettings.AlwaysUpdate && !initialCheck { - log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check") - return - } - - // Start manager if needed - if e.updateManager == nil { - log.Infof("starting auto-update manager") - updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager) - if err != nil { - return - } - e.updateManager = updateManager - e.updateManager.Start(e.ctx) - } - log.Infof("handling auto-update version: %s", autoUpdateSettings.Version) - e.updateManager.SetVersion(autoUpdateSettings.Version) + e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate) } func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { started := time.Now() defer func() { - log.Infof("sync finished in %s", time.Since(started)) + duration := time.Since(started) + log.Infof("sync finished in %s", duration) + e.clientMetrics.RecordSyncDuration(e.ctx, duration) }() e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -843,7 +859,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { } if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { - e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false) + e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate) } if update.GetNetbirdConfig() != nil { @@ -1008,10 +1024,11 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { return errors.New("wireguard interface is not initialized") } - // Cannot update the IP address without restarting the engine because - // the firewall, route manager, and other components cache the old address if e.wgInterface.Address().String() != conf.Address { - log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address) + log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address) + _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) + e.clientCancel() + return ErrResetConnection } if conf.GetSshConfig() != nil { @@ -1079,6 +1096,8 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR StatusRecorder: e.statusRecorder, SyncResponse: syncResponse, LogPath: e.config.LogPath, + TempDir: e.config.TempDir, + ClientMetrics: e.clientMetrics, RefreshStatus: func() { e.RunHealthProbes(true) }, @@ -1316,8 +1335,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { // Test received (upstream) servers for availability right away instead of upon usage. // If no server of a server group responds this will disable the respective handler and retry later. - e.dnsServer.ProbeAvailability() - + go e.dnsServer.ProbeAvailability() return nil } @@ -1534,12 +1552,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV } serviceDependencies := peer.ServiceDependencies{ - StatusRecorder: e.statusRecorder, - Signaler: e.signaler, - IFaceDiscover: e.mobileDep.IFaceDiscover, - RelayManager: e.relayManager, - SrWatcher: e.srWatcher, - Semaphore: e.connSemaphore, + StatusRecorder: e.statusRecorder, + Signaler: e.signaler, + IFaceDiscover: e.mobileDep.IFaceDiscover, + RelayManager: e.relayManager, + SrWatcher: e.srWatcher, + PortForwardManager: e.portForwardManager, + MetricsRecorder: e.clientMetrics, } peerConn, err := peer.NewConn(config, serviceDependencies) if err != nil { @@ -1696,6 +1715,12 @@ func (e *Engine) close() { if e.rpManager != nil { _ = e.rpManager.Close() } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.portForwardManager.GracefullyStop(ctx); err != nil { + log.Warnf("failed to gracefully stop port forwarding manager: %s", err) + } } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { @@ -1799,7 +1824,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS) return dnsServer, nil default: @@ -1824,11 +1849,28 @@ func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager } -// GetFirewallManager returns the firewall manager +// GetFirewallManager returns the firewall manager. func (e *Engine) GetFirewallManager() firewallManager.Manager { return e.firewall } +// GetExposeManager returns the expose session manager. +func (e *Engine) GetExposeManager() *expose.Manager { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + return e.exposeManager +} + +// IsBlockInbound returns whether inbound connections are blocked. +func (e *Engine) IsBlockInbound() bool { + return e.config.BlockInbound +} + +// GetClientMetrics returns the client metrics +func (e *Engine) GetClientMetrics() *metrics.ClientMetrics { + return e.clientMetrics +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 012c8ad6e..9fa4e51b2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -55,6 +55,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -251,9 +252,6 @@ func TestEngine_SSH(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine( ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, &EngineConfig{ WgIfaceName: "utun101", WgAddr: "100.64.0.1/24", @@ -263,10 +261,13 @@ func TestEngine_SSH(t *testing.T) { MTU: iface.DefaultMTU, SSHKey: sshKey, }, + EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil, - nil, ) engine.dnsServer = &dns.MockServer{ @@ -428,13 +429,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun102", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -647,13 +653,18 @@ func TestEngine_Sync(t *testing.T) { return nil } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{SyncFunc: syncFunc}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -812,13 +823,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgAddr := fmt.Sprintf("100.66.%d.1/24", n) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { @@ -1014,13 +1030,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgAddr := fmt.Sprintf("100.66.%d.1/24", n) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) @@ -1518,13 +1539,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin return nil, err } - publicKey, err := mgmtClient.GetServerPublicKey() - if err != nil { - return nil, err - } - info := system.GetInfo(ctx) - resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil) + resp, err := mgmtClient.Register(setupKey, "", info, nil, nil) if err != nil { return nil, err } @@ -1546,7 +1562,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil + e, err := NewEngine(ctx, cancel, conf, EngineServices{ + SignalClient: signalClient, + MgmClient: mgmtClient, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}), nil e.ctx = ctx return e, err } @@ -1614,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1636,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, "", err } diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go new file mode 100644 index 000000000..076f92043 --- /dev/null +++ b/client/internal/expose/manager.go @@ -0,0 +1,104 @@ +package expose + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + + mgm "github.com/netbirdio/netbird/shared/management/client" +) + +const ( + renewTimeout = 10 * time.Second +) + +// Response holds the response from exposing a service. +type Response struct { + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool +} + +// Request holds the parameters for exposing a local service via the management server. +// It is part of the embed API surface and exposed via a type alias. +type Request struct { + NamePrefix string + Domain string + Port uint16 + Protocol ProtocolType + Pin string + Password string + UserGroups []string + ListenPort uint16 +} + +type ManagementClient interface { + CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) + RenewExpose(ctx context.Context, domain string) error + StopExpose(ctx context.Context, domain string) error +} + +// Manager handles expose session lifecycle via the management client. +type Manager struct { + mgmClient ManagementClient + ctx context.Context +} + +// NewManager creates a new expose Manager using the given management client. +func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager { + return &Manager{mgmClient: mgmClient, ctx: ctx} +} + +// Expose creates a new expose session via the management server. +func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) { + log.Infof("exposing service on port %d", req.Port) + resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req)) + if err != nil { + return nil, err + } + + log.Infof("expose session created for %s", resp.Domain) + + return fromClientExposeResponse(resp), nil +} + +// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs. +// It is part of the embed API surface and exposed via a type alias. +func (m *Manager) KeepAlive(ctx context.Context, domain string) error { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + defer m.stop(domain) + + for { + select { + case <-ctx.Done(): + log.Infof("context canceled, stopping keep alive for %s", domain) + + return nil + case <-ticker.C: + if err := m.renew(ctx, domain); err != nil { + log.Errorf("renewing expose session for %s: %v", domain, err) + return err + } + } + } +} + +// renew extends the TTL of an active expose session. +func (m *Manager) renew(ctx context.Context, domain string) error { + renewCtx, cancel := context.WithTimeout(ctx, renewTimeout) + defer cancel() + return m.mgmClient.RenewExpose(renewCtx, domain) +} + +// stop terminates an active expose session. +func (m *Manager) stop(domain string) { + stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout) + defer cancel() + err := m.mgmClient.StopExpose(stopCtx, domain) + if err != nil { + log.Warnf("Failed stopping expose session for %s: %v", domain, err) + } +} diff --git a/client/internal/expose/manager_test.go b/client/internal/expose/manager_test.go new file mode 100644 index 000000000..7d76c9838 --- /dev/null +++ b/client/internal/expose/manager_test.go @@ -0,0 +1,95 @@ +package expose + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + daemonProto "github.com/netbirdio/netbird/client/proto" + mgm "github.com/netbirdio/netbird/shared/management/client" +) + +func TestManager_Expose_Success(t *testing.T) { + mock := &mgm.MockClient{ + CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) { + return &mgm.ExposeResponse{ + ServiceName: "my-service", + ServiceURL: "https://my-service.example.com", + Domain: "my-service.example.com", + }, nil + }, + } + + m := NewManager(context.Background(), mock) + result, err := m.Expose(context.Background(), Request{Port: 8080}) + require.NoError(t, err) + assert.Equal(t, "my-service", result.ServiceName, "service name should match") + assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match") + assert.Equal(t, "my-service.example.com", result.Domain, "domain should match") +} + +func TestManager_Expose_Error(t *testing.T) { + mock := &mgm.MockClient{ + CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) { + return nil, errors.New("permission denied") + }, + } + + m := NewManager(context.Background(), mock) + _, err := m.Expose(context.Background(), Request{Port: 8080}) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission denied", "error should propagate") +} + +func TestManager_Renew_Success(t *testing.T) { + mock := &mgm.MockClient{ + RenewExposeFunc: func(ctx context.Context, domain string) error { + assert.Equal(t, "my-service.example.com", domain, "domain should be passed through") + return nil + }, + } + + m := NewManager(context.Background(), mock) + err := m.renew(context.Background(), "my-service.example.com") + require.NoError(t, err) +} + +func TestManager_Renew_Timeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + mock := &mgm.MockClient{ + RenewExposeFunc: func(ctx context.Context, domain string) error { + return ctx.Err() + }, + } + + m := NewManager(ctx, mock) + err := m.renew(ctx, "my-service.example.com") + require.Error(t, err) +} + +func TestNewRequest(t *testing.T) { + req := &daemonProto.ExposeServiceRequest{ + Port: 8080, + Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS, + Pin: "123456", + Password: "secret", + UserGroups: []string{"group1", "group2"}, + Domain: "custom.example.com", + NamePrefix: "my-prefix", + } + + exposeReq := NewRequest(req) + + assert.Equal(t, uint16(8080), exposeReq.Port, "port should match") + assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match") + assert.Equal(t, "123456", exposeReq.Pin, "pin should match") + assert.Equal(t, "secret", exposeReq.Password, "password should match") + assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match") + assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match") + assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match") +} diff --git a/client/internal/expose/protocol.go b/client/internal/expose/protocol.go new file mode 100644 index 000000000..d5026d51e --- /dev/null +++ b/client/internal/expose/protocol.go @@ -0,0 +1,40 @@ +package expose + +import ( + "fmt" + "strings" +) + +// ProtocolType represents the protocol used for exposing a service. +type ProtocolType int + +const ( + // ProtocolHTTP exposes the service as HTTP. + ProtocolHTTP ProtocolType = 0 + // ProtocolHTTPS exposes the service as HTTPS. + ProtocolHTTPS ProtocolType = 1 + // ProtocolTCP exposes the service as TCP. + ProtocolTCP ProtocolType = 2 + // ProtocolUDP exposes the service as UDP. + ProtocolUDP ProtocolType = 3 + // ProtocolTLS exposes the service as TLS. + ProtocolTLS ProtocolType = 4 +) + +// ParseProtocolType parses a protocol string into a ProtocolType. +func ParseProtocolType(s string) (ProtocolType, error) { + switch strings.ToLower(s) { + case "http": + return ProtocolHTTP, nil + case "https": + return ProtocolHTTPS, nil + case "tcp": + return ProtocolTCP, nil + case "udp": + return ProtocolUDP, nil + case "tls": + return ProtocolTLS, nil + default: + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s) + } +} diff --git a/client/internal/expose/request.go b/client/internal/expose/request.go new file mode 100644 index 000000000..ec75bb276 --- /dev/null +++ b/client/internal/expose/request.go @@ -0,0 +1,42 @@ +package expose + +import ( + daemonProto "github.com/netbirdio/netbird/client/proto" + mgm "github.com/netbirdio/netbird/shared/management/client" +) + +// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest. +func NewRequest(req *daemonProto.ExposeServiceRequest) *Request { + return &Request{ + Port: uint16(req.Port), + Protocol: ProtocolType(req.Protocol), + Pin: req.Pin, + Password: req.Password, + UserGroups: req.UserGroups, + Domain: req.Domain, + NamePrefix: req.NamePrefix, + ListenPort: uint16(req.ListenPort), + } +} + +func toClientExposeRequest(req Request) mgm.ExposeRequest { + return mgm.ExposeRequest{ + NamePrefix: req.NamePrefix, + Domain: req.Domain, + Port: req.Port, + Protocol: int(req.Protocol), + Pin: req.Pin, + Password: req.Password, + UserGroups: req.UserGroups, + ListenPort: req.ListenPort, + } +} + +func fromClientExposeResponse(response *mgm.ExposeResponse) *Response { + return &Response{ + ServiceName: response.ServiceName, + Domain: response.Domain, + ServiceURL: response.ServiceURL, + PortAutoAssigned: response.PortAutoAssigned, + } +} diff --git a/client/internal/metrics/connection_type.go b/client/internal/metrics/connection_type.go new file mode 100644 index 000000000..a3406a6b8 --- /dev/null +++ b/client/internal/metrics/connection_type.go @@ -0,0 +1,17 @@ +package metrics + +// ConnectionType represents the type of peer connection +type ConnectionType string + +const ( + // ConnectionTypeICE represents a direct peer-to-peer connection using ICE + ConnectionTypeICE ConnectionType = "ice" + + // ConnectionTypeRelay represents a relayed connection + ConnectionTypeRelay ConnectionType = "relay" +) + +// String returns the string representation of the connection type +func (c ConnectionType) String() string { + return string(c) +} diff --git a/client/internal/metrics/deployment_type.go b/client/internal/metrics/deployment_type.go new file mode 100644 index 000000000..141173cb8 --- /dev/null +++ b/client/internal/metrics/deployment_type.go @@ -0,0 +1,51 @@ +package metrics + +import ( + "net/url" + "strings" +) + +// DeploymentType represents the type of NetBird deployment +type DeploymentType int + +const ( + // DeploymentTypeUnknown represents an unknown or uninitialized deployment type + DeploymentTypeUnknown DeploymentType = iota + + // DeploymentTypeCloud represents a cloud-hosted NetBird deployment + DeploymentTypeCloud + + // DeploymentTypeSelfHosted represents a self-hosted NetBird deployment + DeploymentTypeSelfHosted +) + +// String returns the string representation of the deployment type +func (d DeploymentType) String() string { + switch d { + case DeploymentTypeCloud: + return "cloud" + case DeploymentTypeSelfHosted: + return "selfhosted" + default: + return "unknown" + } +} + +// DetermineDeploymentType determines if the deployment is cloud or self-hosted +// based on the management URL string +func DetermineDeploymentType(managementURL string) DeploymentType { + if managementURL == "" { + return DeploymentTypeUnknown + } + + u, err := url.Parse(managementURL) + if err != nil { + return DeploymentTypeSelfHosted + } + + if strings.ToLower(u.Hostname()) == "api.netbird.io" { + return DeploymentTypeCloud + } + + return DeploymentTypeSelfHosted +} diff --git a/client/internal/metrics/env.go b/client/internal/metrics/env.go new file mode 100644 index 000000000..1f06ce484 --- /dev/null +++ b/client/internal/metrics/env.go @@ -0,0 +1,93 @@ +package metrics + +import ( + "net/url" + "os" + "strconv" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // EnvMetricsPushEnabled controls whether collected metrics are pushed to the backend. + // Metrics collection itself is always active (for debug bundles). + // Disabled by default. Set NB_METRICS_PUSH_ENABLED=true to enable push. + EnvMetricsPushEnabled = "NB_METRICS_PUSH_ENABLED" + + // EnvMetricsForceSending if set to true, skips remote configuration fetch and forces metric sending + EnvMetricsForceSending = "NB_METRICS_FORCE_SENDING" + + // EnvMetricsConfigURL is the environment variable to override the metrics push config ServerAddress + EnvMetricsConfigURL = "NB_METRICS_CONFIG_URL" + + // EnvMetricsServerURL is the environment variable to override the metrics server address. + // When set, this takes precedence over the server_url from remote push config. + EnvMetricsServerURL = "NB_METRICS_SERVER_URL" + + // EnvMetricsInterval overrides the push interval from the remote config. + // Only affects how often metrics are pushed; remote config availability + // and version range checks are still respected. + // Format: duration string like "1h", "30m", "4h" + EnvMetricsInterval = "NB_METRICS_INTERVAL" + + defaultMetricsConfigURL = "https://ingest.netbird.io/config" +) + +// IsMetricsPushEnabled returns true if metrics push is enabled via NB_METRICS_PUSH_ENABLED env var. +// Disabled by default. Metrics collection is always active for debug bundles. +func IsMetricsPushEnabled() bool { + enabled, _ := strconv.ParseBool(os.Getenv(EnvMetricsPushEnabled)) + return enabled +} + +// getMetricsInterval returns the metrics push interval from NB_METRICS_INTERVAL env var. +// Returns 0 if not set or invalid. +func getMetricsInterval() time.Duration { + intervalStr := os.Getenv(EnvMetricsInterval) + if intervalStr == "" { + return 0 + } + interval, err := time.ParseDuration(intervalStr) + if err != nil { + log.Warnf("invalid metrics interval from env %q: %v", intervalStr, err) + return 0 + } + if interval <= 0 { + log.Warnf("invalid metrics interval from env %q: must be positive", intervalStr) + return 0 + } + return interval +} + +func isForceSending() bool { + force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending)) + return force +} + +// getMetricsConfigURL returns the URL to fetch push configuration from +func getMetricsConfigURL() string { + if envURL := os.Getenv(EnvMetricsConfigURL); envURL != "" { + return envURL + } + return defaultMetricsConfigURL +} + +// getMetricsServerURL returns the metrics server URL from NB_METRICS_SERVER_URL env var. +// Returns nil if not set or invalid. +func getMetricsServerURL() *url.URL { + envURL := os.Getenv(EnvMetricsServerURL) + if envURL == "" { + return nil + } + parsed, err := url.ParseRequestURI(envURL) + if err != nil || parsed.Host == "" { + log.Warnf("invalid metrics server URL %q: must be an absolute HTTP(S) URL", envURL) + return nil + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + log.Warnf("invalid metrics server URL %q: unsupported scheme %q", envURL, parsed.Scheme) + return nil + } + return parsed +} diff --git a/client/internal/metrics/influxdb.go b/client/internal/metrics/influxdb.go new file mode 100644 index 000000000..531f6a986 --- /dev/null +++ b/client/internal/metrics/influxdb.go @@ -0,0 +1,219 @@ +package metrics + +import ( + "context" + "fmt" + "io" + "maps" + "slices" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxSampleAge = 5 * 24 * time.Hour // drop samples older than 5 days + maxBufferSize = 5 * 1024 * 1024 // drop oldest samples when estimated size exceeds 5 MB + // estimatedSampleSize is a rough per-sample memory estimate (measurement + tags + fields + timestamp) + estimatedSampleSize = 256 +) + +// influxSample is a single InfluxDB line protocol entry. +type influxSample struct { + measurement string + tags string + fields map[string]float64 + timestamp time.Time +} + +// influxDBMetrics collects metric events as timestamped samples. +// Each event is recorded with its exact timestamp, pushed once, then cleared. +type influxDBMetrics struct { + mu sync.Mutex + samples []influxSample +} + +func newInfluxDBMetrics() metricsImplementation { + return &influxDBMetrics{} +} +func (m *influxDBMetrics) RecordConnectionStages( + _ context.Context, + agentInfo AgentInfo, + connectionPairID string, + connectionType ConnectionType, + isReconnection bool, + timestamps ConnectionStageTimestamps, +) { + var signalingReceivedToConnection, connectionToWgHandshake, totalDuration float64 + + if !timestamps.SignalingReceived.IsZero() && !timestamps.ConnectionReady.IsZero() { + signalingReceivedToConnection = timestamps.ConnectionReady.Sub(timestamps.SignalingReceived).Seconds() + } + + if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() { + connectionToWgHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds() + } + + if !timestamps.SignalingReceived.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() { + totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.SignalingReceived).Seconds() + } + + attemptType := "initial" + if isReconnection { + attemptType = "reconnection" + } + + connTypeStr := connectionType.String() + tags := fmt.Sprintf("deployment_type=%s,connection_type=%s,attempt_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,connection_pair_id=%s", + agentInfo.DeploymentType.String(), + connTypeStr, + attemptType, + agentInfo.Version, + agentInfo.OS, + agentInfo.Arch, + agentInfo.peerID, + connectionPairID, + ) + + now := time.Now() + + m.mu.Lock() + defer m.mu.Unlock() + + m.samples = append(m.samples, influxSample{ + measurement: "netbird_peer_connection", + tags: tags, + fields: map[string]float64{ + "signaling_to_connection_seconds": signalingReceivedToConnection, + "connection_to_wg_handshake_seconds": connectionToWgHandshake, + "total_seconds": totalDuration, + }, + timestamp: now, + }) + m.trimLocked() + + log.Tracef("peer connection metrics [%s, %s, %s]: signalingReceived→connection: %.3fs, connection→wg_handshake: %.3fs, total: %.3fs", + agentInfo.DeploymentType.String(), connTypeStr, attemptType, signalingReceivedToConnection, connectionToWgHandshake, totalDuration) +} + +func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration) { + tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s", + agentInfo.DeploymentType.String(), + agentInfo.Version, + agentInfo.OS, + agentInfo.Arch, + agentInfo.peerID, + ) + + m.mu.Lock() + defer m.mu.Unlock() + + m.samples = append(m.samples, influxSample{ + measurement: "netbird_sync", + tags: tags, + fields: map[string]float64{ + "duration_seconds": duration.Seconds(), + }, + timestamp: time.Now(), + }) + m.trimLocked() +} + +func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) { + result := "success" + if !success { + result = "failure" + } + + tags := fmt.Sprintf("deployment_type=%s,result=%s,version=%s,os=%s,arch=%s,peer_id=%s", + agentInfo.DeploymentType.String(), + result, + agentInfo.Version, + agentInfo.OS, + agentInfo.Arch, + agentInfo.peerID, + ) + + m.mu.Lock() + defer m.mu.Unlock() + + m.samples = append(m.samples, influxSample{ + measurement: "netbird_login", + tags: tags, + fields: map[string]float64{ + "duration_seconds": duration.Seconds(), + }, + timestamp: time.Now(), + }) + m.trimLocked() + + log.Tracef("login metrics [%s, %s]: duration=%.3fs", agentInfo.DeploymentType.String(), result, duration.Seconds()) +} + +// Export writes pending samples in InfluxDB line protocol format. +// Format: measurement,tag=val,tag=val field=val,field=val timestamp_ns +func (m *influxDBMetrics) Export(w io.Writer) error { + m.mu.Lock() + samples := make([]influxSample, len(m.samples)) + copy(samples, m.samples) + m.mu.Unlock() + + for _, s := range samples { + if _, err := fmt.Fprintf(w, "%s,%s ", s.measurement, s.tags); err != nil { + return err + } + + sortedKeys := slices.Sorted(maps.Keys(s.fields)) + first := true + for _, k := range sortedKeys { + if !first { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + if _, err := fmt.Fprintf(w, "%s=%g", k, s.fields[k]); err != nil { + return err + } + first = false + } + + if _, err := fmt.Fprintf(w, " %d\n", s.timestamp.UnixNano()); err != nil { + return err + } + } + return nil +} + +// Reset clears pending samples after a successful push +func (m *influxDBMetrics) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.samples = m.samples[:0] +} + +// trimLocked removes samples that exceed age or size limits. +// Must be called with m.mu held. +func (m *influxDBMetrics) trimLocked() { + now := time.Now() + + // drop samples older than maxSampleAge + cutoff := 0 + for cutoff < len(m.samples) && now.Sub(m.samples[cutoff].timestamp) > maxSampleAge { + cutoff++ + } + if cutoff > 0 { + copy(m.samples, m.samples[cutoff:]) + m.samples = m.samples[:len(m.samples)-cutoff] + log.Debugf("influxdb metrics: dropped %d samples older than %s", cutoff, maxSampleAge) + } + + // drop oldest samples if estimated size exceeds maxBufferSize + maxSamples := maxBufferSize / estimatedSampleSize + if len(m.samples) > maxSamples { + drop := len(m.samples) - maxSamples + copy(m.samples, m.samples[drop:]) + m.samples = m.samples[:maxSamples] + log.Debugf("influxdb metrics: dropped %d oldest samples to stay under %d MB size limit", drop, maxBufferSize/(1024*1024)) + } +} diff --git a/client/internal/metrics/influxdb_test.go b/client/internal/metrics/influxdb_test.go new file mode 100644 index 000000000..b964e31a3 --- /dev/null +++ b/client/internal/metrics/influxdb_test.go @@ -0,0 +1,229 @@ +package metrics + +import ( + "bytes" + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInfluxDBMetrics_RecordAndExport(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + ts := ConnectionStageTimestamps{ + SignalingReceived: time.Now().Add(-3 * time.Second), + ConnectionReady: time.Now().Add(-2 * time.Second), + WgHandshakeSuccess: time.Now().Add(-1 * time.Second), + } + + m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_peer_connection,") + assert.Contains(t, output, "connection_to_wg_handshake_seconds=") + assert.Contains(t, output, "signaling_to_connection_seconds=") + assert.Contains(t, output, "total_seconds=") +} + +func TestInfluxDBMetrics_ExportDeterministicFieldOrder(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + ts := ConnectionStageTimestamps{ + SignalingReceived: time.Now().Add(-3 * time.Second), + ConnectionReady: time.Now().Add(-2 * time.Second), + WgHandshakeSuccess: time.Now().Add(-1 * time.Second), + } + + // Record multiple times and verify consistent field order + for i := 0; i < 10; i++ { + m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts) + } + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.Len(t, lines, 10) + + // Extract field portion from each line and verify they're all identical + var fieldSections []string + for _, line := range lines { + parts := strings.SplitN(line, " ", 3) + require.Len(t, parts, 3, "each line should have measurement, fields, timestamp") + fieldSections = append(fieldSections, parts[1]) + } + + for i := 1; i < len(fieldSections); i++ { + assert.Equal(t, fieldSections[0], fieldSections[i], "field order should be deterministic across samples") + } + + // Fields should be alphabetically sorted + assert.True(t, strings.HasPrefix(fieldSections[0], "connection_to_wg_handshake_seconds="), + "fields should be sorted: connection_to_wg < signaling_to < total") +} + +func TestInfluxDBMetrics_RecordSyncDuration(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeSelfHosted, + Version: "2.0.0", + OS: "darwin", + Arch: "arm64", + peerID: "def456", + } + + m.RecordSyncDuration(context.Background(), agentInfo, 1500*time.Millisecond) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_sync,") + assert.Contains(t, output, "duration_seconds=1.5") + assert.Contains(t, output, "deployment_type=selfhosted") +} + +func TestInfluxDBMetrics_Reset(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + m.RecordSyncDuration(context.Background(), agentInfo, time.Second) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + assert.NotEmpty(t, buf.String()) + + m.Reset() + + buf.Reset() + err = m.Export(&buf) + require.NoError(t, err) + assert.Empty(t, buf.String(), "should be empty after reset") +} + +func TestInfluxDBMetrics_ExportEmpty(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + assert.Empty(t, buf.String()) +} + +func TestInfluxDBMetrics_TrimByAge(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + m.mu.Lock() + m.samples = append(m.samples, influxSample{ + measurement: "old", + tags: "t=1", + fields: map[string]float64{"v": 1}, + timestamp: time.Now().Add(-maxSampleAge - time.Hour), + }) + m.trimLocked() + remaining := len(m.samples) + m.mu.Unlock() + + assert.Equal(t, 0, remaining, "old samples should be trimmed") +} + +func TestInfluxDBMetrics_RecordLoginDuration(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + m.RecordLoginDuration(context.Background(), agentInfo, 2500*time.Millisecond, true) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_login,") + assert.Contains(t, output, "duration_seconds=2.5") + assert.Contains(t, output, "result=success") +} + +func TestInfluxDBMetrics_RecordLoginDurationFailure(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeSelfHosted, + Version: "1.0.0", + OS: "darwin", + Arch: "arm64", + peerID: "xyz789", + } + + m.RecordLoginDuration(context.Background(), agentInfo, 5*time.Second, false) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_login,") + assert.Contains(t, output, "result=failure") + assert.Contains(t, output, "deployment_type=selfhosted") +} + +func TestInfluxDBMetrics_TrimBySize(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + maxSamples := maxBufferSize / estimatedSampleSize + m.mu.Lock() + for i := 0; i < maxSamples+100; i++ { + m.samples = append(m.samples, influxSample{ + measurement: "test", + tags: "t=1", + fields: map[string]float64{"v": float64(i)}, + timestamp: time.Now(), + }) + } + m.trimLocked() + remaining := len(m.samples) + m.mu.Unlock() + + assert.Equal(t, maxSamples, remaining, "should trim to max samples") +} diff --git a/client/internal/metrics/infra/.env.example b/client/internal/metrics/infra/.env.example new file mode 100644 index 000000000..9c5c1a258 --- /dev/null +++ b/client/internal/metrics/infra/.env.example @@ -0,0 +1,16 @@ +# Copy to .env and adjust values before running docker compose + +# InfluxDB admin (server-side only, never exposed to clients) +INFLUXDB_ADMIN_PASSWORD=changeme +INFLUXDB_ADMIN_TOKEN=changeme + +# Grafana admin credentials +GRAFANA_ADMIN_USER=admin +GRAFANA_ADMIN_PASSWORD=changeme + +# Remote config served by ingest at /config +# Set CONFIG_METRICS_SERVER_URL to the ingest server's public address to enable +CONFIG_METRICS_SERVER_URL= +CONFIG_VERSION_SINCE=0.0.0 +CONFIG_VERSION_UNTIL=99.99.99 +CONFIG_PERIOD_MINUTES=5 diff --git a/client/internal/metrics/infra/.gitignore b/client/internal/metrics/infra/.gitignore new file mode 100644 index 000000000..4c49bd78f --- /dev/null +++ b/client/internal/metrics/infra/.gitignore @@ -0,0 +1 @@ +.env diff --git a/client/internal/metrics/infra/README.md b/client/internal/metrics/infra/README.md new file mode 100644 index 000000000..5a93dbd87 --- /dev/null +++ b/client/internal/metrics/infra/README.md @@ -0,0 +1,194 @@ +# Client Metrics + +Internal documentation for the NetBird client metrics system. + +## Overview + +Client metrics track connection performance and sync durations using InfluxDB line protocol (`influxdb.go`). Each event is pushed once then cleared. + +Metrics collection is always active (for debug bundles). Push to backend is: +- Disabled by default (opt-in via `NB_METRICS_PUSH_ENABLED=true`) +- Managed at daemon layer (survives engine restarts) + +## Architecture + +### Layer Separation + +```text +Daemon Layer (connect.go) + ├─ Creates ClientMetrics instance once + ├─ Starts/stops push lifecycle + └─ Updates AgentInfo on profile switch + │ + ▼ +Engine Layer (engine.go) + └─ Records metrics via ClientMetrics methods +``` + +### Ingest Server + +Clients do not talk to InfluxDB directly. An ingest server sits between clients and InfluxDB: + +```text +Client ──POST──▶ Ingest Server (:8087) ──▶ InfluxDB (internal) + │ + ├─ Validates line protocol + ├─ Allowlists measurements, fields, and tags + ├─ Rejects out-of-bound values + └─ Serves remote config at /config +``` + +- **No secret/token-based client auth** — the ingest server holds the InfluxDB token server-side. Clients must send a hashed peer ID via `X-Peer-ID` header. +- **InfluxDB is not exposed** — only accessible within the docker network +- Source: `ingest/main.go` + +## Metrics Collected + +### Connection Stage Timing + +Measurement: `netbird_peer_connection` + +| Field | Timestamps | Description | +|-------|-----------|-------------| +| `signaling_to_connection_seconds` | `SignalingReceived → ConnectionReady` | ICE/relay negotiation time after the first signal is received from the remote peer | +| `connection_to_wg_handshake_seconds` | `ConnectionReady → WgHandshakeSuccess` | WireGuard cryptographic handshake latency once the transport layer is ready | +| `total_seconds` | `SignalingReceived → WgHandshakeSuccess` | End-to-end connection time anchored at the first received signal | + +Tags: +- `deployment_type`: "cloud" | "selfhosted" | "unknown" +- `connection_type`: "ice" | "relay" +- `attempt_type`: "initial" | "reconnection" +- `version`: NetBird version string +- `os`: Operating system (linux, darwin, windows, android, ios, etc.) +- `arch`: CPU architecture (amd64, arm64, etc.) + +**Note:** `SignalingReceived` is set when the first offer or answer arrives from the remote peer (in both initial and reconnection paths). It excludes the potentially unbounded wait for the remote peer to come online. + +### Sync Duration + +Measurement: `netbird_sync` + +| Field | Description | +|-------|-------------| +| `duration_seconds` | Time to process a sync message from management server | + +Tags: +- `deployment_type`: "cloud" | "selfhosted" | "unknown" +- `version`: NetBird version string +- `os`: Operating system (linux, darwin, windows, android, ios, etc.) +- `arch`: CPU architecture (amd64, arm64, etc.) + +### Login Duration + +Measurement: `netbird_login` + +| Field | Description | +|-------|-------------| +| `duration_seconds` | Time to complete the login/auth exchange with management server | + +Tags: +- `deployment_type`: "cloud" | "selfhosted" | "unknown" +- `result`: "success" | "failure" +- `version`: NetBird version string +- `os`: Operating system (linux, darwin, windows, android, ios, etc.) +- `arch`: CPU architecture (amd64, arm64, etc.) + +## Buffer Limits + +The InfluxDB backend limits in-memory sample storage to prevent unbounded growth when pushes fail: +- **Max age:** Samples older than 5 days are dropped +- **Max size:** Estimated buffer size capped at 5 MB (~20k samples) + +## Configuration + +### Client Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `NB_METRICS_PUSH_ENABLED` | `false` | Enable metrics push to backend | +| `NB_METRICS_SERVER_URL` | *(from remote config)* | Ingest server URL (e.g., `https://ingest.netbird.io`) | +| `NB_METRICS_INTERVAL` | *(from remote config)* | Push interval (e.g., "1m", "30m", "4h") | +| `NB_METRICS_FORCE_SENDING` | `false` | Skip remote config, push unconditionally | +| `NB_METRICS_CONFIG_URL` | `https://ingest.netbird.io/config` | Remote push config URL | + +`NB_METRICS_SERVER_URL` and `NB_METRICS_INTERVAL` override their respective values but do not bypass remote config eligibility checks (version range). Use `NB_METRICS_FORCE_SENDING=true` to skip all remote config gating. + +### Ingest Server Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `INGEST_LISTEN_ADDR` | `:8087` | Listen address | +| `INFLUXDB_URL` | `http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns` | InfluxDB write endpoint | +| `INFLUXDB_TOKEN` | *(required)* | InfluxDB auth token (server-side only) | +| `CONFIG_METRICS_SERVER_URL` | *(empty — disables /config)* | `server_url` in the remote config JSON (the URL clients push metrics to) | +| `CONFIG_VERSION_SINCE` | `0.0.0` | Minimum client version to push metrics | +| `CONFIG_VERSION_UNTIL` | `99.99.99` | Maximum client version to push metrics | +| `CONFIG_PERIOD_MINUTES` | `5` | Push interval in minutes | + +The ingest server serves a remote config JSON at `GET /config` when `CONFIG_METRICS_SERVER_URL` is set. Clients can use `NB_METRICS_CONFIG_URL=http:///config` to fetch it. + +### Configuration Precedence + +For URL and Interval, the precedence is: +1. **Environment variable** - `NB_METRICS_SERVER_URL` / `NB_METRICS_INTERVAL` +2. **Remote config** - fetched from `NB_METRICS_CONFIG_URL` +3. **Default** - 5 minute interval, URL from remote config + +## Push Behavior + +1. `StartPush()` spawns background goroutine with timer +2. First push happens immediately on startup +3. Periodically: `push()` → `Export()` → HTTP POST to ingest server +4. On failure: log error, continue (non-blocking) +5. On success: `Reset()` clears pushed samples +6. `StopPush()` cancels context and waits for goroutine + +Samples are collected with exact timestamps, pushed once, then cleared. No data is resent. + +## Local Development Setup + +### 1. Configure and Start Services + +```bash +# From this directory (client/internal/metrics/infra) +cp .env.example .env +# Edit .env to set INFLUXDB_ADMIN_PASSWORD, INFLUXDB_ADMIN_TOKEN, and GRAFANA_ADMIN_PASSWORD +docker compose up -d +``` + +This starts: +- **Ingest server** on http://localhost:8087 — accepts client metrics (requires `X-Peer-ID` header, no secret/token auth) +- **InfluxDB** — internal only, not exposed to host +- **Grafana** on http://localhost:3001 + +### 2. Configure Client + +```bash +export NB_METRICS_PUSH_ENABLED=true +export NB_METRICS_FORCE_SENDING=true +export NB_METRICS_SERVER_URL=http://localhost:8087 +export NB_METRICS_INTERVAL=1m +``` + +### 3. Run Client + +```bash +cd ../../../.. +go run ./client/ up +``` + +### 4. View in Grafana + +- **InfluxDB dashboard:** http://localhost:3001/d/netbird-influxdb-metrics + +### 5. Verify Data + +```bash +# Query via InfluxDB (using admin token from .env) +docker compose exec influxdb influx query \ + 'from(bucket: "metrics") |> range(start: -1h)' \ + --org netbird + +# Check ingest server health +curl http://localhost:8087/health +``` \ No newline at end of file diff --git a/client/internal/metrics/infra/docker-compose.yml b/client/internal/metrics/infra/docker-compose.yml new file mode 100644 index 000000000..0f2b6b889 --- /dev/null +++ b/client/internal/metrics/infra/docker-compose.yml @@ -0,0 +1,69 @@ +version: '3.8' + +services: + ingest: + container_name: ingest + build: + context: ./ingest + ports: + - "8087:8087" + environment: + - INGEST_LISTEN_ADDR=:8087 + - INFLUXDB_URL=http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns + - INFLUXDB_TOKEN=${INFLUXDB_ADMIN_TOKEN:?required} + - CONFIG_METRICS_SERVER_URL=${CONFIG_METRICS_SERVER_URL:-} + - CONFIG_VERSION_SINCE=${CONFIG_VERSION_SINCE:-0.0.0} + - CONFIG_VERSION_UNTIL=${CONFIG_VERSION_UNTIL:-99.99.99} + - CONFIG_PERIOD_MINUTES=${CONFIG_PERIOD_MINUTES:-5} + depends_on: + - influxdb + restart: unless-stopped + networks: + - metrics + + influxdb: + container_name: influxdb + image: influxdb:2 + # No ports exposed — only accessible within the metrics network + volumes: + - influxdb-data:/var/lib/influxdb2 + - ./influxdb/scripts:/docker-entrypoint-initdb.d + environment: + - DOCKER_INFLUXDB_INIT_MODE=setup + - DOCKER_INFLUXDB_INIT_USERNAME=admin + - DOCKER_INFLUXDB_INIT_PASSWORD=${INFLUXDB_ADMIN_PASSWORD:?required} + - DOCKER_INFLUXDB_INIT_ORG=netbird + - DOCKER_INFLUXDB_INIT_BUCKET=metrics + - DOCKER_INFLUXDB_INIT_RETENTION=365d + - DOCKER_INFLUXDB_INIT_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-} + restart: unless-stopped + networks: + - metrics + + grafana: + container_name: grafana + image: grafana/grafana:11.6.0 + ports: + - "3001:3000" + environment: + - GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER:-admin} + - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:?required} + - GF_USERS_ALLOW_SIGN_UP=false + - GF_INSTALL_PLUGINS= + - INFLUXDB_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-} + volumes: + - grafana-data:/var/lib/grafana + - ./grafana/provisioning:/etc/grafana/provisioning + depends_on: + - influxdb + restart: unless-stopped + networks: + - metrics + +volumes: + influxdb-data: + grafana-data: + +networks: + metrics: + driver: bridge diff --git a/client/internal/metrics/infra/grafana/provisioning/dashboards/dashboard.yml b/client/internal/metrics/infra/grafana/provisioning/dashboards/dashboard.yml new file mode 100644 index 000000000..a7e8d3989 --- /dev/null +++ b/client/internal/metrics/infra/grafana/provisioning/dashboards/dashboard.yml @@ -0,0 +1,12 @@ +apiVersion: 1 + +providers: + - name: 'NetBird Dashboards' + orgId: 1 + folder: '' + type: file + disableDeletion: false + updateIntervalSeconds: 10 + allowUiUpdates: true + options: + path: /etc/grafana/provisioning/dashboards/json \ No newline at end of file diff --git a/client/internal/metrics/infra/grafana/provisioning/dashboards/json/netbird-influxdb-metrics.json b/client/internal/metrics/infra/grafana/provisioning/dashboards/json/netbird-influxdb-metrics.json new file mode 100644 index 000000000..2bcc9cbab --- /dev/null +++ b/client/internal/metrics/infra/grafana/provisioning/dashboards/json/netbird-influxdb-metrics.json @@ -0,0 +1,280 @@ +{ + "uid": "netbird-influxdb-metrics", + "title": "NetBird Client Metrics (InfluxDB)", + "tags": ["netbird", "connections", "influxdb"], + "timezone": "browser", + "panels": [ + { + "id": 5, + "title": "Sync Duration Extremes", + "type": "stat", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")", + "refId": "A" + }, + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0 + } + }, + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "colorMode": "value", + "graphMode": "none", + "textMode": "auto" + } + }, + { + "id": 6, + "title": "Total Connection Time Extremes", + "type": "stat", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")", + "refId": "A" + }, + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0 + } + }, + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "colorMode": "value", + "graphMode": "none", + "textMode": "auto" + } + }, + { + "id": 1, + "title": "Sync Duration", + "type": "timeseries", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Sync Duration\")", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0, + "custom": { + "drawStyle": "points", + "pointSize": 5 + } + } + } + }, + { + "id": 4, + "title": "ICE vs Relay", + "type": "piechart", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> drop(columns: [\"deployment_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"connection_pair_id\"])\n |> last()\n |> group(columns: [\"connection_type\"])\n |> count()", + "refId": "A" + } + ], + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "pieType": "donut", + "tooltip": { + "mode": "multi" + } + } + }, + { + "id": 2, + "title": "Connection Stage Durations (avg)", + "type": "bargauge", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"signaling_to_connection_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Signaling to Connection\"})", + "refId": "A" + }, + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"connection_to_wg_handshake_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Connection to WG Handshake\"})", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0 + } + }, + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "orientation": "horizontal", + "displayMode": "gradient" + } + }, + { + "id": 3, + "title": "Total Connection Time", + "type": "timeseries", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> set(key: \"_field\", value: \"Total Connection Time\")", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0, + "custom": { + "drawStyle": "points", + "pointSize": 5 + } + } + } + }, + { + "id": 7, + "title": "Login Duration", + "type": "timeseries", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Login Duration\")", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0, + "custom": { + "drawStyle": "points", + "pointSize": 5 + } + } + } + }, + { + "id": 8, + "title": "Login Success vs Failure", + "type": "piechart", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"result\"])\n |> count()", + "refId": "A" + } + ], + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "pieType": "donut", + "tooltip": { + "mode": "multi" + } + } + } + ], + "schemaVersion": 27, + "version": 2, + "refresh": "30s" +} diff --git a/client/internal/metrics/infra/grafana/provisioning/datasources/influxdb.yml b/client/internal/metrics/infra/grafana/provisioning/datasources/influxdb.yml new file mode 100644 index 000000000..69b96a93a --- /dev/null +++ b/client/internal/metrics/infra/grafana/provisioning/datasources/influxdb.yml @@ -0,0 +1,15 @@ +apiVersion: 1 + +datasources: + - name: InfluxDB + uid: influxdb + type: influxdb + access: proxy + url: http://influxdb:8086 + editable: true + jsonData: + version: Flux + organization: netbird + defaultBucket: metrics + secureJsonData: + token: ${INFLUXDB_ADMIN_TOKEN} \ No newline at end of file diff --git a/client/internal/metrics/infra/influxdb/scripts/create-tokens.sh b/client/internal/metrics/infra/influxdb/scripts/create-tokens.sh new file mode 100755 index 000000000..2464803e8 --- /dev/null +++ b/client/internal/metrics/infra/influxdb/scripts/create-tokens.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Creates a scoped InfluxDB read-only token for Grafana. +# Clients do not need a token — they push via the ingest server. + +BUCKET_ID=$(influx bucket list --org netbird --name metrics --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1) +ORG_ID=$(influx org list --name netbird --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1) + +if [[ -z "$BUCKET_ID" ]] || [[ -z "$ORG_ID" ]]; then + echo "ERROR: Could not determine bucket or org ID" >&2 + echo "BUCKET_ID=$BUCKET_ID ORG_ID=$ORG_ID" >&2 + exit 1 +fi + +# Create read-only token for Grafana +READ_TOKEN=$(influx auth create \ + --org netbird \ + --read-bucket "$BUCKET_ID" \ + --description "Grafana read-only token" \ + --json | grep -oP '"token"\s*:\s*"\K[^"]+' | head -1) + +echo "" +echo "============================================" +echo "GRAFANA READ-ONLY TOKEN:" +echo "$READ_TOKEN" +echo "============================================" \ No newline at end of file diff --git a/client/internal/metrics/infra/ingest/Dockerfile b/client/internal/metrics/infra/ingest/Dockerfile new file mode 100644 index 000000000..3620c524b --- /dev/null +++ b/client/internal/metrics/infra/ingest/Dockerfile @@ -0,0 +1,10 @@ +FROM golang:1.25-alpine AS build +WORKDIR /app +COPY go.mod main.go ./ +RUN CGO_ENABLED=0 go build -o ingest . + +FROM alpine:3.20 +RUN adduser -D -H ingest +COPY --from=build /app/ingest /usr/local/bin/ingest +USER ingest +ENTRYPOINT ["ingest"] \ No newline at end of file diff --git a/client/internal/metrics/infra/ingest/go.mod b/client/internal/metrics/infra/ingest/go.mod new file mode 100644 index 000000000..aaf1ea9da --- /dev/null +++ b/client/internal/metrics/infra/ingest/go.mod @@ -0,0 +1,11 @@ +module github.com/netbirdio/netbird/client/internal/metrics/infra/ingest + +go 1.25 + +require github.com/stretchr/testify v1.11.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/client/internal/metrics/infra/ingest/go.sum b/client/internal/metrics/infra/ingest/go.sum new file mode 100644 index 000000000..c4c1710c4 --- /dev/null +++ b/client/internal/metrics/infra/ingest/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/client/internal/metrics/infra/ingest/main.go b/client/internal/metrics/infra/ingest/main.go new file mode 100644 index 000000000..a5031a873 --- /dev/null +++ b/client/internal/metrics/infra/ingest/main.go @@ -0,0 +1,355 @@ +package main + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +const ( + defaultListenAddr = ":8087" + defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns" + maxBodySize = 50 * 1024 * 1024 // 50 MB max request body + maxDurationSeconds = 300.0 // reject any duration field > 5 minutes + peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars + maxTagValueLength = 64 // reject tag values longer than this +) + +type measurementSpec struct { + allowedFields map[string]bool + allowedTags map[string]bool +} + +var allowedMeasurements = map[string]measurementSpec{ + "netbird_peer_connection": { + allowedFields: map[string]bool{ + "signaling_to_connection_seconds": true, + "connection_to_wg_handshake_seconds": true, + "total_seconds": true, + }, + allowedTags: map[string]bool{ + "deployment_type": true, + "connection_type": true, + "attempt_type": true, + "version": true, + "os": true, + "arch": true, + "peer_id": true, + "connection_pair_id": true, + }, + }, + "netbird_sync": { + allowedFields: map[string]bool{ + "duration_seconds": true, + }, + allowedTags: map[string]bool{ + "deployment_type": true, + "version": true, + "os": true, + "arch": true, + "peer_id": true, + }, + }, + "netbird_login": { + allowedFields: map[string]bool{ + "duration_seconds": true, + }, + allowedTags: map[string]bool{ + "deployment_type": true, + "result": true, + "version": true, + "os": true, + "arch": true, + "peer_id": true, + }, + }, +} + +func main() { + listenAddr := envOr("INGEST_LISTEN_ADDR", defaultListenAddr) + influxURL := envOr("INFLUXDB_URL", defaultInfluxDBURL) + influxToken := os.Getenv("INFLUXDB_TOKEN") + + if influxToken == "" { + log.Fatal("INFLUXDB_TOKEN is required") + } + + client := &http.Client{Timeout: 10 * time.Second} + + http.HandleFunc("/", handleIngest(client, influxURL, influxToken)) + + // Build config JSON once at startup from env vars + configJSON := buildConfigJSON() + if configJSON != nil { + log.Printf("serving remote config at /config") + } + + http.HandleFunc("/config", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if configJSON == nil { + http.Error(w, "config not configured", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(configJSON) //nolint:errcheck + }) + + http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok") //nolint:errcheck + }) + + log.Printf("ingest server listening on %s, forwarding to %s", listenAddr, influxURL) + if err := http.ListenAndServe(listenAddr, nil); err != nil { //nolint:gosec + log.Fatal(err) + } +} + +func handleIngest(client *http.Client, influxURL, influxToken string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := validateAuth(r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + body, err := readBody(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if len(body) > maxBodySize { + http.Error(w, "body too large", http.StatusRequestEntityTooLarge) + return + } + + validated, err := validateLineProtocol(body) + if err != nil { + log.Printf("WARN validation failed from %s: %v", r.RemoteAddr, err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + forwardToInflux(w, r, client, influxURL, influxToken, validated) + } +} + +func forwardToInflux(w http.ResponseWriter, r *http.Request, client *http.Client, influxURL, influxToken string, body []byte) { + req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, influxURL, bytes.NewReader(body)) + if err != nil { + log.Printf("ERROR create request: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + req.Header.Set("Content-Type", "text/plain; charset=utf-8") + req.Header.Set("Authorization", "Token "+influxToken) + + resp, err := client.Do(req) + if err != nil { + log.Printf("ERROR forward to influxdb: %v", err) + http.Error(w, "upstream error", http.StatusBadGateway) + return + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) //nolint:errcheck +} + +// validateAuth checks that the X-Peer-ID header contains a valid hashed peer ID. +func validateAuth(r *http.Request) error { + peerID := r.Header.Get("X-Peer-ID") + if peerID == "" { + return fmt.Errorf("missing X-Peer-ID header") + } + if len(peerID) != peerIDLength { + return fmt.Errorf("invalid X-Peer-ID header length") + } + if _, err := hex.DecodeString(peerID); err != nil { + return fmt.Errorf("invalid X-Peer-ID header format") + } + return nil +} + +// readBody reads the request body, decompressing gzip if Content-Encoding indicates it. +func readBody(r *http.Request) ([]byte, error) { + reader := io.LimitReader(r.Body, maxBodySize+1) + + if r.Header.Get("Content-Encoding") == "gzip" { + gz, err := gzip.NewReader(reader) + if err != nil { + return nil, fmt.Errorf("invalid gzip: %w", err) + } + defer gz.Close() + reader = io.LimitReader(gz, maxBodySize+1) + } + + return io.ReadAll(reader) +} + +// validateLineProtocol parses InfluxDB line protocol lines, +// whitelists measurements and fields, and checks value bounds. +func validateLineProtocol(body []byte) ([]byte, error) { + lines := strings.Split(strings.TrimSpace(string(body)), "\n") + var valid []string + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + if err := validateLine(line); err != nil { + return nil, err + } + + valid = append(valid, line) + } + + if len(valid) == 0 { + return nil, fmt.Errorf("no valid lines") + } + + return []byte(strings.Join(valid, "\n") + "\n"), nil +} + +func validateLine(line string) error { + // line protocol: measurement,tag=val,tag=val field=val,field=val timestamp + parts := strings.SplitN(line, " ", 3) + if len(parts) < 2 { + return fmt.Errorf("invalid line protocol: %q", truncate(line, 100)) + } + + // parts[0] is "measurement,tag=val,tag=val" + measurementAndTags := strings.Split(parts[0], ",") + measurement := measurementAndTags[0] + + spec, ok := allowedMeasurements[measurement] + if !ok { + return fmt.Errorf("unknown measurement: %q", measurement) + } + + // Validate tags (everything after measurement name in parts[0]) + for _, tagPair := range measurementAndTags[1:] { + if err := validateTag(tagPair, measurement, spec.allowedTags); err != nil { + return err + } + } + + // Validate fields + for _, pair := range strings.Split(parts[1], ",") { + if err := validateField(pair, measurement, spec.allowedFields); err != nil { + return err + } + } + + return nil +} + +func validateTag(pair, measurement string, allowedTags map[string]bool) error { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("invalid tag: %q", pair) + } + + tagName := kv[0] + if !allowedTags[tagName] { + return fmt.Errorf("unknown tag %q in measurement %q", tagName, measurement) + } + + if len(kv[1]) > maxTagValueLength { + return fmt.Errorf("tag value too long for %q: %d > %d", tagName, len(kv[1]), maxTagValueLength) + } + + return nil +} + +func validateField(pair, measurement string, allowedFields map[string]bool) error { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("invalid field: %q", pair) + } + + fieldName := kv[0] + if !allowedFields[fieldName] { + return fmt.Errorf("unknown field %q in measurement %q", fieldName, measurement) + } + + val, err := strconv.ParseFloat(kv[1], 64) + if err != nil { + return fmt.Errorf("invalid field value %q for %q", kv[1], fieldName) + } + if val < 0 { + return fmt.Errorf("negative value for %q: %g", fieldName, val) + } + if strings.HasSuffix(fieldName, "_seconds") && val > maxDurationSeconds { + return fmt.Errorf("%q too large: %g > %g", fieldName, val, maxDurationSeconds) + } + + return nil +} + +// buildConfigJSON builds the remote config JSON from env vars. +// Returns nil if required vars are not set. +func buildConfigJSON() []byte { + serverURL := os.Getenv("CONFIG_METRICS_SERVER_URL") + versionSince := envOr("CONFIG_VERSION_SINCE", "0.0.0") + versionUntil := envOr("CONFIG_VERSION_UNTIL", "99.99.99") + periodMinutes := envOr("CONFIG_PERIOD_MINUTES", "5") + + if serverURL == "" { + return nil + } + + period, err := strconv.Atoi(periodMinutes) + if err != nil || period <= 0 { + log.Printf("WARN invalid CONFIG_PERIOD_MINUTES: %q, using 5", periodMinutes) + period = 5 + } + + cfg := map[string]any{ + "server_url": serverURL, + "version-since": versionSince, + "version-until": versionUntil, + "period_minutes": period, + } + + data, err := json.Marshal(cfg) + if err != nil { + log.Printf("ERROR failed to marshal config: %v", err) + return nil + } + return data +} + +func envOr(key, defaultVal string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultVal +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} diff --git a/client/internal/metrics/infra/ingest/main_test.go b/client/internal/metrics/infra/ingest/main_test.go new file mode 100644 index 000000000..bacaa4588 --- /dev/null +++ b/client/internal/metrics/infra/ingest/main_test.go @@ -0,0 +1,124 @@ +package main + +import ( + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateLine_ValidPeerConnection(t *testing.T) { + line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789,connection_pair_id=pair1234 signaling_to_connection_seconds=1.5,connection_to_wg_handshake_seconds=0.5,total_seconds=2 1234567890` + assert.NoError(t, validateLine(line)) +} + +func TestValidateLine_ValidSync(t *testing.T) { + line := `netbird_sync,deployment_type=selfhosted,version=2.0.0,os=darwin,arch=arm64,peer_id=abcdef0123456789 duration_seconds=1.5 1234567890` + assert.NoError(t, validateLine(line)) +} + +func TestValidateLine_ValidLogin(t *testing.T) { + line := `netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789 duration_seconds=3.2 1234567890` + assert.NoError(t, validateLine(line)) +} + +func TestValidateLine_UnknownMeasurement(t *testing.T) { + line := `unknown_metric,foo=bar value=1 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown measurement") +} + +func TestValidateLine_UnknownTag(t *testing.T) { + line := `netbird_sync,deployment_type=cloud,evil_tag=injected,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown tag") +} + +func TestValidateLine_UnknownField(t *testing.T) { + line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc injected_field=1 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown field") +} + +func TestValidateLine_NegativeValue(t *testing.T) { + line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=-1.5 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "negative") +} + +func TestValidateLine_DurationTooLarge(t *testing.T) { + line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "too large") +} + +func TestValidateLine_TotalSecondsTooLarge(t *testing.T) { + line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "too large") +} + +func TestValidateLine_TagValueTooLong(t *testing.T) { + longTag := strings.Repeat("a", maxTagValueLength+1) + line := `netbird_sync,deployment_type=` + longTag + `,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890` + err := validateLine(line) + require.Error(t, err) + assert.Contains(t, err.Error(), "tag value too long") +} + +func TestValidateLineProtocol_MultipleLines(t *testing.T) { + body := []byte( + "netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" + + "netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=2.0 1234567890\n", + ) + validated, err := validateLineProtocol(body) + require.NoError(t, err) + assert.Contains(t, string(validated), "netbird_sync") + assert.Contains(t, string(validated), "netbird_login") +} + +func TestValidateLineProtocol_RejectsOnBadLine(t *testing.T) { + body := []byte( + "netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" + + "evil_metric,foo=bar value=1 1234567890\n", + ) + _, err := validateLineProtocol(body) + require.Error(t, err) +} + +func TestValidateAuth(t *testing.T) { + tests := []struct { + name string + peerID string + wantErr bool + }{ + {"valid hex", "abcdef0123456789", false}, + {"empty", "", true}, + {"too short", "abcdef01234567", true}, + {"too long", "abcdef01234567890", true}, + {"invalid hex", "ghijklmnopqrstuv", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/", nil) + if tt.peerID != "" { + r.Header.Set("X-Peer-ID", tt.peerID) + } + err := validateAuth(r) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/client/internal/metrics/metrics.go b/client/internal/metrics/metrics.go new file mode 100644 index 000000000..4ebb43496 --- /dev/null +++ b/client/internal/metrics/metrics.go @@ -0,0 +1,224 @@ +package metrics + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/metrics/remoteconfig" +) + +// AgentInfo holds static information about the agent +type AgentInfo struct { + DeploymentType DeploymentType + Version string + OS string // runtime.GOOS (linux, darwin, windows, etc.) + Arch string // runtime.GOARCH (amd64, arm64, etc.) + peerID string // anonymised peer identifier (SHA-256 of WireGuard public key) +} + +// peerIDFromPublicKey returns a truncated SHA-256 hash (8 bytes / 16 hex chars) of the given WireGuard public key. +func peerIDFromPublicKey(pubKey string) string { + hash := sha256.Sum256([]byte(pubKey)) + return hex.EncodeToString(hash[:8]) +} + +// connectionPairID returns a deterministic identifier for a connection between two peers. +// It sorts the two peer IDs before hashing so the same pair always produces the same ID +// regardless of which side computes it. +func connectionPairID(peerID1, peerID2 string) string { + a, b := peerID1, peerID2 + if a > b { + a, b = b, a + } + hash := sha256.Sum256([]byte(a + b)) + return hex.EncodeToString(hash[:8]) +} + +// metricsImplementation defines the internal interface for metrics implementations +type metricsImplementation interface { + // RecordConnectionStages records connection stage metrics from timestamps + RecordConnectionStages( + ctx context.Context, + agentInfo AgentInfo, + connectionPairID string, + connectionType ConnectionType, + isReconnection bool, + timestamps ConnectionStageTimestamps, + ) + + // RecordSyncDuration records how long it took to process a sync message + RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration) + + // RecordLoginDuration records how long the login to management took + RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool) + + // Export exports metrics in InfluxDB line protocol format + Export(w io.Writer) error + + // Reset clears all collected metrics + Reset() +} + +type ClientMetrics struct { + impl metricsImplementation + + agentInfo AgentInfo + mu sync.RWMutex + + push *Push + pushMu sync.Mutex + wg sync.WaitGroup + pushCancel context.CancelFunc +} + +// ConnectionStageTimestamps holds timestamps for each connection stage +type ConnectionStageTimestamps struct { + SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection) + ConnectionReady time.Time + WgHandshakeSuccess time.Time +} + +// String returns a human-readable representation of the connection stage timestamps +func (c ConnectionStageTimestamps) String() string { + return fmt.Sprintf("ConnectionStageTimestamps{SignalingReceived=%v, ConnectionReady=%v, WgHandshakeSuccess=%v}", + c.SignalingReceived.Format(time.RFC3339Nano), + c.ConnectionReady.Format(time.RFC3339Nano), + c.WgHandshakeSuccess.Format(time.RFC3339Nano), + ) +} + +// RecordConnectionStages calculates stage durations from timestamps and records them. +// remotePubKey is the remote peer's WireGuard public key; it will be hashed for anonymisation. +func (c *ClientMetrics) RecordConnectionStages( + ctx context.Context, + remotePubKey string, + connectionType ConnectionType, + isReconnection bool, + timestamps ConnectionStageTimestamps, +) { + if c == nil { + return + } + c.mu.RLock() + agentInfo := c.agentInfo + c.mu.RUnlock() + + remotePeerID := peerIDFromPublicKey(remotePubKey) + pairID := connectionPairID(agentInfo.peerID, remotePeerID) + c.impl.RecordConnectionStages(ctx, agentInfo, pairID, connectionType, isReconnection, timestamps) +} + +// RecordSyncDuration records the duration of sync message processing +func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Duration) { + if c == nil { + return + } + c.mu.RLock() + agentInfo := c.agentInfo + c.mu.RUnlock() + + c.impl.RecordSyncDuration(ctx, agentInfo, duration) +} + +// RecordLoginDuration records how long the login to management server took +func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) { + if c == nil { + return + } + c.mu.RLock() + agentInfo := c.agentInfo + c.mu.RUnlock() + + c.impl.RecordLoginDuration(ctx, agentInfo, duration, success) +} + +// UpdateAgentInfo updates the agent information (e.g., when switching profiles). +// publicKey is the WireGuard public key; it will be hashed for anonymisation. +func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) { + if c == nil { + return + } + + agentInfo.peerID = peerIDFromPublicKey(publicKey) + + c.mu.Lock() + c.agentInfo = agentInfo + c.mu.Unlock() + + c.pushMu.Lock() + push := c.push + c.pushMu.Unlock() + if push != nil { + push.SetPeerID(agentInfo.peerID) + } +} + +// Export exports metrics to the writer +func (c *ClientMetrics) Export(w io.Writer) error { + if c == nil { + return nil + } + + return c.impl.Export(w) +} + +// StartPush starts periodic pushing of metrics with the given configuration +// Precedence: PushConfig.ServerAddress > remote config server_url +func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) { + if c == nil { + return + } + + c.pushMu.Lock() + defer c.pushMu.Unlock() + + if c.push != nil { + log.Warnf("metrics push already running") + return + } + + c.mu.RLock() + agentVersion := c.agentInfo.Version + peerID := c.agentInfo.peerID + c.mu.RUnlock() + + configManager := remoteconfig.NewManager(getMetricsConfigURL(), remoteconfig.DefaultMinRefreshInterval) + push, err := NewPush(c.impl, configManager, config, agentVersion) + if err != nil { + log.Errorf("failed to create metrics push: %v", err) + return + } + push.SetPeerID(peerID) + + ctx, cancel := context.WithCancel(ctx) + c.pushCancel = cancel + + c.wg.Add(1) + go func() { + defer c.wg.Done() + push.Start(ctx) + }() + c.push = push +} + +func (c *ClientMetrics) StopPush() { + if c == nil { + return + } + c.pushMu.Lock() + defer c.pushMu.Unlock() + if c.push == nil { + return + } + + c.pushCancel() + c.wg.Wait() + c.push = nil +} diff --git a/client/internal/metrics/metrics_default.go b/client/internal/metrics/metrics_default.go new file mode 100644 index 000000000..927ab51d1 --- /dev/null +++ b/client/internal/metrics/metrics_default.go @@ -0,0 +1,11 @@ +//go:build !js + +package metrics + +// NewClientMetrics creates a new ClientMetrics instance +func NewClientMetrics(agentInfo AgentInfo) *ClientMetrics { + return &ClientMetrics{ + impl: newInfluxDBMetrics(), + agentInfo: agentInfo, + } +} diff --git a/client/internal/metrics/metrics_js.go b/client/internal/metrics/metrics_js.go new file mode 100644 index 000000000..dfa6d8243 --- /dev/null +++ b/client/internal/metrics/metrics_js.go @@ -0,0 +1,8 @@ +//go:build js + +package metrics + +// NewClientMetrics returns nil on WASM builds — all ClientMetrics methods are nil-safe. +func NewClientMetrics(AgentInfo) *ClientMetrics { + return nil +} diff --git a/client/internal/metrics/push.go b/client/internal/metrics/push.go new file mode 100644 index 000000000..ee0508f36 --- /dev/null +++ b/client/internal/metrics/push.go @@ -0,0 +1,289 @@ +package metrics + +import ( + "bytes" + "compress/gzip" + "context" + "fmt" + "net/http" + "net/url" + "sync" + "time" + + goversion "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/metrics/remoteconfig" +) + +const ( + // defaultPushInterval is the default interval for pushing metrics + defaultPushInterval = 5 * time.Minute +) + +// defaultMetricsServerURL is used as fallback when NB_METRICS_FORCE_SENDING is true +var defaultMetricsServerURL *url.URL + +func init() { + defaultMetricsServerURL, _ = url.Parse("https://ingest.netbird.io") +} + +// PushConfig holds configuration for metrics push +type PushConfig struct { + // ServerAddress is the metrics server URL. If nil, uses remote config server_url. + ServerAddress *url.URL + // Interval is how often to push metrics. If 0, uses remote config interval or defaultPushInterval. + Interval time.Duration + // ForceSending skips remote configuration fetch and version checks, pushing unconditionally. + ForceSending bool +} + +// PushConfigFromEnv builds a PushConfig from environment variables. +func PushConfigFromEnv() PushConfig { + config := PushConfig{} + + config.ForceSending = isForceSending() + config.ServerAddress = getMetricsServerURL() + config.Interval = getMetricsInterval() + + return config +} + +// remoteConfigProvider abstracts remote push config fetching for testability +type remoteConfigProvider interface { + RefreshIfNeeded(ctx context.Context) *remoteconfig.Config +} + +// Push handles periodic pushing of metrics +type Push struct { + metrics metricsImplementation + configManager remoteConfigProvider + agentVersion *goversion.Version + + peerID string + peerMu sync.RWMutex + + client *http.Client + cfgForceSending bool + cfgInterval time.Duration + cfgAddress *url.URL +} + +// NewPush creates a new Push instance with configuration resolution +func NewPush(metrics metricsImplementation, configManager remoteConfigProvider, config PushConfig, agentVersion string) (*Push, error) { + var cfgInterval time.Duration + var cfgAddress *url.URL + + if config.ForceSending { + cfgInterval = config.Interval + if config.Interval <= 0 { + cfgInterval = defaultPushInterval + } + + cfgAddress = config.ServerAddress + if cfgAddress == nil { + cfgAddress = defaultMetricsServerURL + } + } else { + cfgAddress = config.ServerAddress + + if config.Interval < 0 { + log.Warnf("negative metrics push interval %s", config.Interval) + } else { + cfgInterval = config.Interval + } + } + + parsedVersion, err := goversion.NewVersion(agentVersion) + if err != nil { + if !config.ForceSending { + return nil, fmt.Errorf("parse agent version %q: %w", agentVersion, err) + } + } + + return &Push{ + metrics: metrics, + configManager: configManager, + agentVersion: parsedVersion, + cfgForceSending: config.ForceSending, + cfgInterval: cfgInterval, + cfgAddress: cfgAddress, + client: &http.Client{ + Timeout: 10 * time.Second, + }, + }, nil +} + +// SetPeerID updates the hashed peer ID used for the Authorization header. +func (p *Push) SetPeerID(peerID string) { + p.peerMu.Lock() + p.peerID = peerID + p.peerMu.Unlock() +} + +// Start starts the periodic push loop. +// The env interval override controls tick frequency but does not bypass remote config +// version gating. Use ForceSending to skip remote config entirely. +func (p *Push) Start(ctx context.Context) { + // Log initial state + switch { + case p.cfgForceSending: + log.Infof("started metrics push with force sending to %s, interval %s", p.cfgAddress, p.cfgInterval) + case p.cfgAddress != nil: + log.Infof("started metrics push with server URL override: %s", p.cfgAddress.String()) + default: + log.Infof("started metrics push, server URL will be resolved from remote config") + } + + timer := time.NewTimer(0) // fire immediately on first iteration + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + log.Debug("stopping metrics push") + return + case <-timer.C: + } + + pushURL, interval := p.resolve(ctx) + if pushURL != "" { + if err := p.push(ctx, pushURL); err != nil { + log.Errorf("failed to push metrics: %v", err) + } + } + + if interval <= 0 { + interval = defaultPushInterval + } + timer.Reset(interval) + } +} + +// resolve returns the push URL and interval for the next cycle. +// Returns empty pushURL to skip this cycle. +func (p *Push) resolve(ctx context.Context) (pushURL string, interval time.Duration) { + if p.cfgForceSending { + return p.resolveServerURL(nil), p.cfgInterval + } + + config := p.configManager.RefreshIfNeeded(ctx) + if config == nil { + log.Debug("no metrics push config available, waiting to retry") + return "", defaultPushInterval + } + + // prefer env variables instead of remote config + if p.cfgInterval > 0 { + interval = p.cfgInterval + } else { + interval = config.Interval + } + + if !isVersionInRange(p.agentVersion, config.VersionSince, config.VersionUntil) { + log.Debugf("agent version %s not in range [%s, %s), skipping metrics push", + p.agentVersion, config.VersionSince, config.VersionUntil) + return "", interval + } + + pushURL = p.resolveServerURL(&config.ServerURL) + if pushURL == "" { + log.Warn("no metrics server URL available, skipping push") + } + return pushURL, interval +} + +// push exports metrics and sends them to the metrics server +func (p *Push) push(ctx context.Context, pushURL string) error { + // Export metrics without clearing + var buf bytes.Buffer + if err := p.metrics.Export(&buf); err != nil { + return fmt.Errorf("export metrics: %w", err) + } + + // Don't push if there are no metrics + if buf.Len() == 0 { + log.Tracef("no metrics to push") + return nil + } + + // Gzip compress the body + compressed, err := gzipCompress(buf.Bytes()) + if err != nil { + return fmt.Errorf("gzip compress: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", pushURL, compressed) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "text/plain; charset=utf-8") + req.Header.Set("Content-Encoding", "gzip") + + p.peerMu.RLock() + peerID := p.peerID + p.peerMu.RUnlock() + if peerID != "" { + req.Header.Set("X-Peer-ID", peerID) + } + + // Send request + resp, err := p.client.Do(req) + if err != nil { + return fmt.Errorf("send request: %w", err) + } + defer func() { + if resp.Body == nil { + return + } + if err := resp.Body.Close(); err != nil { + log.Warnf("failed to close response body: %v", err) + } + }() + + // Check response status + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("push failed with status %d", resp.StatusCode) + } + + log.Debugf("successfully pushed metrics to %s", pushURL) + p.metrics.Reset() + return nil +} + +// resolveServerURL determines the push URL. +// Precedence: envAddress (env var) > remote config server_url +func (p *Push) resolveServerURL(remoteServerURL *url.URL) string { + var baseURL *url.URL + if p.cfgAddress != nil { + baseURL = p.cfgAddress + } else { + baseURL = remoteServerURL + } + + if baseURL == nil { + return "" + } + + return baseURL.String() +} + +// gzipCompress compresses data using gzip and returns the compressed buffer. +func gzipCompress(data []byte) (*bytes.Buffer, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + if _, err := gz.Write(data); err != nil { + _ = gz.Close() + return nil, err + } + if err := gz.Close(); err != nil { + return nil, err + } + return &buf, nil +} + +// isVersionInRange checks if current falls within [since, until) +func isVersionInRange(current, since, until *goversion.Version) bool { + return !current.LessThan(since) && current.LessThan(until) +} diff --git a/client/internal/metrics/push_test.go b/client/internal/metrics/push_test.go new file mode 100644 index 000000000..20a509da1 --- /dev/null +++ b/client/internal/metrics/push_test.go @@ -0,0 +1,343 @@ +package metrics + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + "time" + + goversion "github.com/hashicorp/go-version" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/metrics/remoteconfig" +) + +func mustVersion(s string) *goversion.Version { + v, err := goversion.NewVersion(s) + if err != nil { + panic(err) + } + return v +} + +func mustURL(s string) url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return *u +} + +func parseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +func testConfig(serverURL, since, until string, period time.Duration) *remoteconfig.Config { + return &remoteconfig.Config{ + ServerURL: mustURL(serverURL), + VersionSince: mustVersion(since), + VersionUntil: mustVersion(until), + Interval: period, + } +} + +// mockConfigProvider implements remoteConfigProvider for testing +type mockConfigProvider struct { + config *remoteconfig.Config +} + +func (m *mockConfigProvider) RefreshIfNeeded(_ context.Context) *remoteconfig.Config { + return m.config +} + +// mockMetrics implements metricsImplementation for testing +type mockMetrics struct { + exportData string +} + +func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ string, _ ConnectionType, _ bool, _ ConnectionStageTimestamps) { +} + +func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) { +} + +func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) { +} + +func (m *mockMetrics) Export(w io.Writer) error { + if m.exportData != "" { + _, err := w.Write([]byte(m.exportData)) + return err + } + return nil +} + +func (m *mockMetrics) Reset() { +} + +func TestPush_OverrideIntervalPushes(t *testing.T) { + var pushCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pushCount.Add(1) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{ + Interval: 50 * time.Millisecond, + ServerAddress: parseURL(server.URL), + }, "1.0.0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + push.Start(ctx) + close(done) + }() + + require.Eventually(t, func() bool { + return pushCount.Load() >= 3 + }, 2*time.Second, 10*time.Millisecond) + + cancel() + <-done +} + +func TestPush_RemoteConfigVersionInRange(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.NotEmpty(t, pushURL) + assert.Equal(t, 1*time.Minute, interval) +} + +func TestPush_RemoteConfigVersionOutOfRange(t *testing.T) { + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: testConfig("http://localhost", "1.0.0", "1.5.0", 1*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{}, "2.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.Empty(t, pushURL) + assert.Equal(t, 1*time.Minute, interval) +} + +func TestPush_NoConfigReturnsDefault(t *testing.T) { + metrics := &mockMetrics{} + configProvider := &mockConfigProvider{config: nil} + + push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.Empty(t, pushURL) + assert.Equal(t, defaultPushInterval, interval) +} + +func TestPush_OverrideIntervalRespectsVersionCheck(t *testing.T) { + metrics := &mockMetrics{} + configProvider := &mockConfigProvider{config: testConfig("http://localhost", "3.0.0", "4.0.0", 60*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{ + Interval: 30 * time.Second, + ServerAddress: parseURL("http://localhost"), + }, "1.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.Empty(t, pushURL) // version out of range + assert.Equal(t, 30*time.Second, interval) // but uses override interval +} + +func TestPush_OverrideIntervalUsedWhenVersionInRange(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{} + configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{ + Interval: 30 * time.Second, + }, "1.5.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.NotEmpty(t, pushURL) + assert.Equal(t, 30*time.Second, interval) +} + +func TestPush_NoMetricsSkipsPush(t *testing.T) { + var pushCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pushCount.Add(1) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: ""} // no metrics to export + configProvider := &mockConfigProvider{config: nil} + + push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0") + require.NoError(t, err) + + err = push.push(context.Background(), server.URL) + assert.NoError(t, err) + assert.Equal(t, int32(0), pushCount.Load()) +} + +func TestPush_ServerURLFromRemoteConfig(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.Contains(t, pushURL, server.URL) + assert.Equal(t, 1*time.Minute, interval) +} + +func TestPush_ServerAddressOverridesTakePrecedenceOverRemoteConfig(t *testing.T) { + overrideServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer overrideServer.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: testConfig("http://remote-config-server", "1.0.0", "2.0.0", 1*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{ + ServerAddress: parseURL(overrideServer.URL), + }, "1.5.0") + require.NoError(t, err) + + pushURL, _ := push.resolve(context.Background()) + assert.Contains(t, pushURL, overrideServer.URL) + assert.NotContains(t, pushURL, "remote-config-server") +} + +func TestPush_OverrideIntervalWithoutOverrideURL_UsesRemoteConfigURL(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)} + + push, err := NewPush(metrics, configProvider, PushConfig{ + Interval: 30 * time.Second, + }, "1.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.Contains(t, pushURL, server.URL) + assert.Equal(t, 30*time.Second, interval) +} + +func TestPush_NoConfigSkipsPush(t *testing.T) { + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: nil} + + push, err := NewPush(metrics, configProvider, PushConfig{ + Interval: 30 * time.Second, + }, "1.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.Empty(t, pushURL) + assert.Equal(t, defaultPushInterval, interval) // no config available, use default retry interval +} + +func TestPush_ForceSendingSkipsRemoteConfig(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: nil} + + push, err := NewPush(metrics, configProvider, PushConfig{ + ForceSending: true, + Interval: 1 * time.Minute, + ServerAddress: parseURL(server.URL), + }, "1.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.NotEmpty(t, pushURL) + assert.Equal(t, 1*time.Minute, interval) +} + +func TestPush_ForceSendingUsesDefaultInterval(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + metrics := &mockMetrics{exportData: "test_metric 1\n"} + configProvider := &mockConfigProvider{config: nil} + + push, err := NewPush(metrics, configProvider, PushConfig{ + ForceSending: true, + ServerAddress: parseURL(server.URL), + }, "1.0.0") + require.NoError(t, err) + + pushURL, interval := push.resolve(context.Background()) + assert.NotEmpty(t, pushURL) + assert.Equal(t, defaultPushInterval, interval) +} + +func TestIsVersionInRange(t *testing.T) { + tests := []struct { + name string + current string + since string + until string + expected bool + }{ + {"at lower bound inclusive", "1.2.2", "1.2.2", "1.2.3", true}, + {"in range", "1.2.2", "1.2.0", "1.3.0", true}, + {"at upper bound exclusive", "1.2.3", "1.2.2", "1.2.3", false}, + {"below range", "1.2.1", "1.2.2", "1.2.3", false}, + {"above range", "1.3.0", "1.2.2", "1.2.3", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isVersionInRange(mustVersion(tt.current), mustVersion(tt.since), mustVersion(tt.until))) + }) + } +} diff --git a/client/internal/metrics/remoteconfig/manager.go b/client/internal/metrics/remoteconfig/manager.go new file mode 100644 index 000000000..01c37891f --- /dev/null +++ b/client/internal/metrics/remoteconfig/manager.go @@ -0,0 +1,149 @@ +package remoteconfig + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sync" + "time" + + goversion "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" +) + +const ( + DefaultMinRefreshInterval = 30 * time.Minute +) + +// Config holds the parsed remote push configuration +type Config struct { + ServerURL url.URL + VersionSince *goversion.Version + VersionUntil *goversion.Version + Interval time.Duration +} + +// rawConfig is the JSON wire format fetched from the remote server +type rawConfig struct { + ServerURL string `json:"server_url"` + VersionSince string `json:"version-since"` + VersionUntil string `json:"version-until"` + PeriodMinutes int `json:"period_minutes"` +} + +// Manager handles fetching and caching remote push configuration +type Manager struct { + configURL string + minRefreshInterval time.Duration + client *http.Client + + mu sync.Mutex + lastConfig *Config + lastFetched time.Time +} + +func NewManager(configURL string, minRefreshInterval time.Duration) *Manager { + return &Manager{ + configURL: configURL, + minRefreshInterval: minRefreshInterval, + client: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// RefreshIfNeeded fetches new config if the cached one is stale. +// Returns the current config (possibly just fetched) or nil if unavailable. +func (m *Manager) RefreshIfNeeded(ctx context.Context) *Config { + m.mu.Lock() + defer m.mu.Unlock() + + if m.isConfigFresh() { + return m.lastConfig + } + + fetchedConfig, err := m.fetch(ctx) + m.lastFetched = time.Now() + if err != nil { + log.Warnf("failed to fetch metrics remote config: %v", err) + return m.lastConfig // return cached (may be nil) + } + + m.lastConfig = fetchedConfig + + log.Tracef("fetched metrics remote config: version-since=%s version-until=%s period=%s", + fetchedConfig.VersionSince, fetchedConfig.VersionUntil, fetchedConfig.Interval) + + return fetchedConfig +} + +func (m *Manager) isConfigFresh() bool { + if m.lastConfig == nil { + return false + } + return time.Since(m.lastFetched) < m.minRefreshInterval +} + +func (m *Manager) fetch(ctx context.Context) (*Config, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.configURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + resp, err := m.client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer func() { + if resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + + var raw rawConfig + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + if raw.PeriodMinutes <= 0 { + return nil, fmt.Errorf("invalid period_minutes: %d", raw.PeriodMinutes) + } + + if raw.ServerURL == "" { + return nil, fmt.Errorf("server_url is required") + } + + serverURL, err := url.Parse(raw.ServerURL) + if err != nil { + return nil, fmt.Errorf("parse server_url %q: %w", raw.ServerURL, err) + } + + since, err := goversion.NewVersion(raw.VersionSince) + if err != nil { + return nil, fmt.Errorf("parse version-since %q: %w", raw.VersionSince, err) + } + + until, err := goversion.NewVersion(raw.VersionUntil) + if err != nil { + return nil, fmt.Errorf("parse version-until %q: %w", raw.VersionUntil, err) + } + + return &Config{ + ServerURL: *serverURL, + VersionSince: since, + VersionUntil: until, + Interval: time.Duration(raw.PeriodMinutes) * time.Minute, + }, nil +} diff --git a/client/internal/metrics/remoteconfig/manager_test.go b/client/internal/metrics/remoteconfig/manager_test.go new file mode 100644 index 000000000..68ca3b4c4 --- /dev/null +++ b/client/internal/metrics/remoteconfig/manager_test.go @@ -0,0 +1,197 @@ +package remoteconfig + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testMinRefresh = 100 * time.Millisecond + +func TestManager_FetchSuccess(t *testing.T) { + server := newConfigServer(t, rawConfig{ + ServerURL: "https://ingest.example.com", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: 60, + }) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + + require.NotNil(t, config) + assert.Equal(t, "https://ingest.example.com", config.ServerURL.String()) + assert.Equal(t, "1.0.0", config.VersionSince.String()) + assert.Equal(t, "2.0.0", config.VersionUntil.String()) + assert.Equal(t, 60*time.Minute, config.Interval) +} + +func TestManager_CachesConfig(t *testing.T) { + var fetchCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fetchCount.Add(1) + err := json.NewEncoder(w).Encode(rawConfig{ + ServerURL: "https://ingest.example.com", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: 60, + }) + require.NoError(t, err) + })) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + + // First call fetches + config1 := mgr.RefreshIfNeeded(context.Background()) + require.NotNil(t, config1) + assert.Equal(t, int32(1), fetchCount.Load()) + + // Second call uses cache (within minRefreshInterval) + config2 := mgr.RefreshIfNeeded(context.Background()) + require.NotNil(t, config2) + assert.Equal(t, int32(1), fetchCount.Load()) + assert.Equal(t, config1, config2) +} + +func TestManager_RefetchesWhenStale(t *testing.T) { + var fetchCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fetchCount.Add(1) + err := json.NewEncoder(w).Encode(rawConfig{ + ServerURL: "https://ingest.example.com", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: 60, + }) + require.NoError(t, err) + })) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + + // First fetch + mgr.RefreshIfNeeded(context.Background()) + assert.Equal(t, int32(1), fetchCount.Load()) + + // Wait for config to become stale + time.Sleep(testMinRefresh + 10*time.Millisecond) + + // Should refetch + mgr.RefreshIfNeeded(context.Background()) + assert.Equal(t, int32(2), fetchCount.Load()) +} + +func TestManager_FetchFailureReturnsNil(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + + assert.Nil(t, config) +} + +func TestManager_FetchFailureReturnsCached(t *testing.T) { + var fetchCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fetchCount.Add(1) + if fetchCount.Load() > 1 { + w.WriteHeader(http.StatusInternalServerError) + return + } + err := json.NewEncoder(w).Encode(rawConfig{ + ServerURL: "https://ingest.example.com", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: 60, + }) + require.NoError(t, err) + })) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + + // First call succeeds + config1 := mgr.RefreshIfNeeded(context.Background()) + require.NotNil(t, config1) + + // Wait for config to become stale + time.Sleep(testMinRefresh + 10*time.Millisecond) + + // Second call fails but returns cached + config2 := mgr.RefreshIfNeeded(context.Background()) + require.NotNil(t, config2) + assert.Equal(t, config1, config2) +} + +func TestManager_RejectsInvalidPeriod(t *testing.T) { + tests := []struct { + name string + period int + }{ + {"zero", 0}, + {"negative", -5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := newConfigServer(t, rawConfig{ + ServerURL: "https://ingest.example.com", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: tt.period, + }) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + assert.Nil(t, config) + }) + } +} + +func TestManager_RejectsEmptyServerURL(t *testing.T) { + server := newConfigServer(t, rawConfig{ + ServerURL: "", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: 60, + }) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + assert.Nil(t, config) +} + +func TestManager_RejectsInvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("not json")) + require.NoError(t, err) + })) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + assert.Nil(t, config) +} + +func newConfigServer(t *testing.T, config rawConfig) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(config) + require.NoError(t, err) + })) +} diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 7c95e2b99..310d61a25 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -22,4 +22,8 @@ type MobileDependency struct { DnsManager dns.IosDnsManager FileDescriptor int32 StateFilePath string + + // TempDir is a writable directory for temporary files (e.g., debug bundle zip). + // On Android, this should be set to the app's cache directory. + TempDir string } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index a4ffa3a25..2420b1fdf 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -7,7 +7,9 @@ import ( "fmt" "net/netip" "sync" + "time" + "github.com/cenkalti/backoff/v4" "github.com/google/uuid" log "github.com/sirupsen/logrus" nfct "github.com/ti-mo/conntrack" @@ -17,31 +19,64 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -const defaultChannelSize = 100 +const ( + defaultChannelSize = 100 + reconnectInitInterval = 5 * time.Second + reconnectMaxInterval = 5 * time.Minute + reconnectRandomization = 0.5 +) + +// listener abstracts a netlink conntrack connection for testability. +type listener interface { + Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error) + Close() error +} // ConnTrack manages kernel-based conntrack events type ConnTrack struct { flowLogger nftypes.FlowLogger iface nftypes.IFaceMapper - conn *nfct.Conn + conn listener mux sync.Mutex + dial func() (listener, error) instanceID uuid.UUID started bool done chan struct{} sysctlModified bool } +// DialFunc is a constructor for netlink conntrack connections. +type DialFunc func() (listener, error) + +// Option configures a ConnTrack instance. +type Option func(*ConnTrack) + +// WithDialer overrides the default netlink dialer, primarily for testing. +func WithDialer(dial DialFunc) Option { + return func(c *ConnTrack) { + c.dial = dial + } +} + +func defaultDial() (listener, error) { + return nfct.Dial(nil) +} + // New creates a new connection tracker that interfaces with the kernel's conntrack system -func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { - return &ConnTrack{ +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack { + ct := &ConnTrack{ flowLogger: flowLogger, iface: iface, instanceID: uuid.New(), - started: false, + dial: defaultDial, done: make(chan struct{}, 1), } + for _, opt := range opts { + opt(ct) + } + return ct } // Start begins tracking connections by listening for conntrack events. This method is idempotent. @@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error { c.EnableAccounting() } - conn, err := nfct.Dial(nil) + conn, err := c.dial() if err != nil { + c.RestoreAccounting() return fmt.Errorf("dial conntrack: %w", err) } c.conn = conn @@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error { log.Errorf("Error closing conntrack connection: %v", err) } c.conn = nil + c.RestoreAccounting() return fmt.Errorf("start conntrack listener: %w", err) } + // Drain any stale stop signal from a previous cycle. + select { + case <-c.done: + default: + } + c.started = true go c.receiverRoutine(events, errChan) @@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) case event := <-events: c.handleEvent(event) case err := <-errChan: - log.Errorf("Error from conntrack event listener: %v", err) - if err := c.conn.Close(); err != nil { - log.Errorf("Error closing conntrack connection: %v", err) + if events, errChan = c.handleListenerError(err); events == nil { + return } - return case <-c.done: return } } } +// handleListenerError closes the failed connection and attempts to reconnect. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) { + log.Warnf("conntrack event listener failed: %v", err) + c.closeConn() + return c.reconnect() +} + +func (c *ConnTrack) closeConn() { + c.mux.Lock() + defer c.mux.Unlock() + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Debugf("close conntrack connection: %v", err) + } + c.conn = nil + } +} + +// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) { + bo := &backoff.ExponentialBackOff{ + InitialInterval: reconnectInitInterval, + RandomizationFactor: reconnectRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: reconnectMaxInterval, + MaxElapsedTime: 0, // retry indefinitely + Clock: backoff.SystemClock, + } + bo.Reset() + + for { + delay := bo.NextBackOff() + log.Infof("reconnecting conntrack listener in %s", delay) + + select { + case <-c.done: + c.mux.Lock() + c.started = false + c.mux.Unlock() + return nil, nil + case <-time.After(delay): + } + + conn, err := c.dial() + if err != nil { + log.Warnf("reconnect conntrack dial: %v", err) + continue + } + + events := make(chan nfct.Event, defaultChannelSize) + errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{ + netfilter.GroupCTNew, + netfilter.GroupCTDestroy, + }) + if err != nil { + log.Warnf("reconnect conntrack listen: %v", err) + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + continue + } + + c.mux.Lock() + if !c.started { + // Stop() ran while we were reconnecting. + c.mux.Unlock() + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + return nil, nil + } + c.conn = conn + c.mux.Unlock() + + log.Infof("conntrack listener reconnected successfully") + + return events, errChan + } +} + // Stop stops the connection tracking. This method is idempotent. func (c *ConnTrack) Stop() { c.mux.Lock() @@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error { c.mux.Lock() defer c.mux.Unlock() - if c.started { - select { - case c.done <- struct{}{}: - default: - } + if !c.started { + return nil } + select { + case c.done <- struct{}{}: + default: + } + + c.started = false + + var closeErr error if c.conn != nil { - err := c.conn.Close() + closeErr = c.conn.Close() c.conn = nil - c.started = false + } - c.RestoreAccounting() + c.RestoreAccounting() - if err != nil { - return fmt.Errorf("close conntrack: %w", err) - } + if closeErr != nil { + return fmt.Errorf("close conntrack: %w", closeErr) } return nil diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go new file mode 100644 index 000000000..35ceec90d --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -0,0 +1,224 @@ +//go:build linux && !android + +package conntrack + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nfct "github.com/ti-mo/conntrack" + "github.com/ti-mo/netfilter" +) + +type mockListener struct { + errChan chan error + closed atomic.Bool + closedCh chan struct{} +} + +func newMockListener() *mockListener { + return &mockListener{ + errChan: make(chan error, 1), + closedCh: make(chan struct{}), + } +} + +func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) { + return m.errChan, nil +} + +func (m *mockListener) Close() error { + if m.closed.CompareAndSwap(false, true) { + close(m.closedCh) + } + return nil +} + +func TestReconnectAfterError(t *testing.T) { + first := newMockListener() + second := newMockListener() + third := newMockListener() + listeners := []*mockListener{first, second, third} + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := int(callCount.Add(1)) - 1 + return listeners[n], nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Inject an error on the first listener. + first.errChan <- assert.AnError + + // Wait for reconnect to complete. + require.Eventually(t, func() bool { + return callCount.Load() >= 2 + }, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection") + + // The first connection must have been closed. + select { + case <-first.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("first connection was not closed") + } + + // Verify the receiver is still running by injecting and handling a second error. + second.errChan <- assert.AnError + + require.Eventually(t, func() bool { + return callCount.Load() >= 3 + }, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed") + + ct.Stop() +} + +func TestStopDuringReconnectBackoff(t *testing.T) { + mock := newMockListener() + + ct := New(nil, nil, WithDialer(func() (listener, error) { + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger an error so the receiver enters reconnect. + mock.errChan <- assert.AnError + + // Wait for the error handler to close the old listener before calling Stop. + select { + case <-mock.closedCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for reconnect to start") + } + + // Stop while reconnecting. + ct.Stop() + + ct.mux.Lock() + assert.False(t, ct.started, "started should be false after Stop") + assert.Nil(t, ct.conn, "conn should be nil after Stop") + ct.mux.Unlock() +} + +func TestStopRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + // Second dial: signal that we're in progress, wait for test to call Stop. + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger error to enter reconnect. + first.errChan <- assert.AnError + + // Wait for reconnect's second dial to begin. + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Stop while dial is in progress (conn is nil at this point). + ct.Stop() + + // Let the dial complete. reconnect should detect started==false and close the new conn. + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Stop") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestCloseRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + first.errChan <- assert.AnError + + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Close while dial is in progress (conn is nil). + require.NoError(t, ct.Close()) + + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Close") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestStartIsIdempotent(t *testing.T) { + mock := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + callCount.Add(1) + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Second Start should be a no-op. + err = ct.Start(false) + require.NoError(t, err) + + assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once") + + ct.Stop() +} diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go index c287236e8..a4a4f76ac 100644 --- a/client/internal/networkmonitor/check_change_common.go +++ b/client/internal/networkmonitor/check_change_common.go @@ -22,51 +22,56 @@ func prepareFd() (int, error) { func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - buf := make([]byte, 2048) - n, err := unix.Read(fd, buf) + // Wait until fd is readable or context is cancelled, to avoid a busy-loop + // when the routing socket returns EAGAIN (e.g. immediately after wakeup). + if err := waitReadable(ctx, fd); err != nil { + return err + } + + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) { + continue + } + if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) { + return fmt.Errorf("routing socket closed: %w", err) + } + return fmt.Errorf("read routing socket: %w", err) + } + + if n < unix.SizeofRtMsghdr { + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) if err != nil { - if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Warnf("Network monitor: failed to read from routing socket: %v", err) - } - continue - } - if n < unix.SizeofRtMsghdr { - log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + log.Debugf("Network monitor: error parsing routing message: %v", err) continue } - msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + if route.Dst.Bits() != 0 { + continue + } + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } switch msg.Type { - // handle route changes - case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) - if err != nil { - log.Debugf("Network monitor: error parsing routing message: %v", err) - continue - } - - if route.Dst.Bits() != 0 { - continue - } - - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - } - switch msg.Type { - case unix.RTM_ADD: - log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + return nil + case unix.RTM_DELETE: + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) return nil - case unix.RTM_DELETE: - if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - return nil - } } } } @@ -90,3 +95,33 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) { return systemops.MsgToRoute(msg) } + +// waitReadable blocks until fd has data to read, or ctx is cancelled. +func waitReadable(ctx context.Context, fd int) error { + var fdset unix.FdSet + if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) { + return fmt.Errorf("fd %d out of range for FdSet", fd) + } + + for { + if err := ctx.Err(); err != nil { + return err + } + + fdset = unix.FdSet{} + fdset.Set(fd) + // Use a 1-second timeout so we can re-check ctx periodically. + tv := unix.Timeval{Sec: 1} + n, err := unix.Select(fd+1, &fdset, nil, nil, &tv) + if err != nil { + if errors.Is(err, unix.EINTR) { + continue + } + return fmt.Errorf("select on routing socket: %w", err) + } + if n > 0 { + return nil + } + // timeout — loop back and re-check ctx + } +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 05a397f3d..8d1585b3f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -3,7 +3,6 @@ package peer import ( "context" "fmt" - "math/rand" "net" "net/netip" "runtime" @@ -16,26 +15,39 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/peer/conntype" "github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/worker" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" - semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) +// MetricsRecorder is an interface for recording peer connection metrics +type MetricsRecorder interface { + RecordConnectionStages( + ctx context.Context, + remotePubKey string, + connectionType metrics.ConnectionType, + isReconnection bool, + timestamps metrics.ConnectionStageTimestamps, + ) +} + type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler IFaceDiscover stdnet.ExternalIFaceDiscover RelayManager *relayClient.Manager SrWatcher *guard.SRWatcher - Semaphore *semaphoregroup.SemaphoreGroup PeerConnDispatcher *dispatcher.ConnectionDispatcher + PortForwardManager *portforward.Manager + MetricsRecorder MetricsRecorder } type WgConfig struct { @@ -77,16 +89,17 @@ type ConnConfig struct { } type Conn struct { - Log *log.Entry - mu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc - config ConnConfig - statusRecorder *Status - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - relayManager *relayClient.Manager - srWatcher *guard.SRWatcher + Log *log.Entry + mu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + config ConnConfig + statusRecorder *Status + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + relayManager *relayClient.Manager + srWatcher *guard.SRWatcher + portForwardManager *portforward.Manager onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string) @@ -111,14 +124,17 @@ type Conn struct { wgProxyRelay wgproxy.Proxy handshaker *Handshaker - guard *guard.Guard - semaphore *semaphoregroup.SemaphoreGroup - wg sync.WaitGroup + guard *guard.Guard + wg sync.WaitGroup // debug purpose dumpState *stateDump endpointUpdater *EndpointUpdater + + // Connection stage timestamps for metrics + metricsRecorder MetricsRecorder + metricsStages *MetricsStages } // NewConn creates a new not opened Conn to the remote peer. @@ -132,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { dumpState := newStateDump(config.Key, connLog, services.StatusRecorder) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: dumpState, - endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), - wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + portForwardManager: services.PortForwardManager, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: dumpState, + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), + wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), + metricsRecorder: services.MetricsRecorder, } return conn, nil @@ -154,18 +171,16 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. func (conn *Conn) Open(engineCtx context.Context) error { - if err := conn.semaphore.Add(engineCtx); err != nil { - return err - } - conn.mu.Lock() defer conn.mu.Unlock() if conn.opened { - conn.semaphore.Done() return nil } + // Allocate new metrics stages so old goroutines don't corrupt new state + conn.metricsStages = &MetricsStages{} + conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) @@ -173,12 +188,11 @@ func (conn *Conn) Open(engineCtx context.Context) error { relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) if err != nil { - conn.semaphore.Done() return err } conn.workerICE = workerICE - conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) + conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) if !isForceRelayed() { @@ -207,10 +221,6 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.wg.Add(1) go func() { defer conn.wg.Done() - - conn.waitInitialRandomSleepTime(conn.ctx) - conn.semaphore.Done() - conn.guard.Start(conn.ctx, conn.onGuardEvent) }() conn.opened = true @@ -350,7 +360,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn if conn.currentConnPriority > priority { conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority) conn.statusICE.SetConnected() - conn.updateIceState(iceConnInfo) + conn.updateIceState(iceConnInfo, time.Now()) return } @@ -390,7 +400,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn } conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) - conn.enableWgWatcherIfNeeded() + updateTime := time.Now() + conn.enableWgWatcherIfNeeded(updateTime) presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { @@ -406,8 +417,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn conn.currentConnPriority = priority conn.statusICE.SetConnected() - conn.updateIceState(iceConnInfo) - conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) + conn.updateIceState(iceConnInfo, updateTime) + conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr, updateTime) } func (conn *Conn) onICEStateDisconnected(sessionChanged bool) { @@ -459,6 +470,10 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) { conn.disableWgWatcherIfNeeded() + if conn.currentConnPriority == conntype.None { + conn.metricsStages.Disconnected() + } + peerState := State{ PubKey: conn.config.Key, ConnStatus: conn.evalStatus(), @@ -499,7 +514,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.setRelayedProxy(wgProxy) conn.statusRelay.SetConnected() - conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, time.Now()) return } @@ -508,7 +523,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { if controller { wgProxy.Work() } - conn.enableWgWatcherIfNeeded() + updateTime := time.Now() + conn.enableWgWatcherIfNeeded(updateTime) if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) @@ -519,13 +535,16 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { if !controller { wgProxy.Work() } + + wgConfigWorkaround() + conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.currentConnPriority = conntype.Relay conn.statusRelay.SetConnected() conn.setRelayedProxy(wgProxy) - conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, updateTime) conn.Log.Infof("start to communicate with peer via relay") - conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) + conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr, updateTime) } func (conn *Conn) onRelayDisconnected() { @@ -563,6 +582,10 @@ func (conn *Conn) handleRelayDisconnectedLocked() { conn.disableWgWatcherIfNeeded() + if conn.currentConnPriority == conntype.None { + conn.metricsStages.Disconnected() + } + peerState := State{ PubKey: conn.config.Key, ConnStatus: conn.evalStatus(), @@ -603,10 +626,10 @@ func (conn *Conn) onWGDisconnected() { } } -func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { +func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte, updateTime time.Time) { peerState := State{ PubKey: conn.config.Key, - ConnStatusUpdate: time.Now(), + ConnStatusUpdate: updateTime, ConnStatus: conn.evalStatus(), Relayed: conn.isRelayed(), RelayServerAddress: relayServerAddr, @@ -619,10 +642,10 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by } } -func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) { +func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo, updateTime time.Time) { peerState := State{ PubKey: conn.config.Key, - ConnStatusUpdate: time.Now(), + ConnStatusUpdate: updateTime, ConnStatus: conn.evalStatus(), Relayed: iceConnInfo.Relayed, LocalIceCandidateType: iceConnInfo.LocalIceCandidateType, @@ -660,29 +683,18 @@ func (conn *Conn) setStatusToDisconnected() { } } -func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) { +func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string, updateTime time.Time) { if runtime.GOOS == "ios" { runtime.GC() } + conn.metricsStages.RecordConnectionReady(updateTime) + if conn.onConnected != nil { conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr) } } -func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { - maxWait := 300 - duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond - - timeout := time.NewTimer(duration) - defer timeout.Stop() - - select { - case <-ctx.Done(): - case <-timeout.C: - } -} - func (conn *Conn) isRelayed() bool { switch conn.currentConnPriority { case conntype.Relay, conntype.ICETurn: @@ -729,14 +741,14 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { return true } -func (conn *Conn) enableWgWatcherIfNeeded() { +func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { if !conn.wgWatcher.IsEnabled() { wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx) conn.wgWatcherCancel = wgWatcherCancel conn.wgWatcherWg.Add(1) go func() { defer conn.wgWatcherWg.Done() - conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected) + conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess) }() } } @@ -811,6 +823,41 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// onWGHandshakeSuccess is called when the first WireGuard handshake is detected +func (conn *Conn) onWGHandshakeSuccess(when time.Time) { + conn.metricsStages.RecordWGHandshakeSuccess(when) + conn.recordConnectionMetrics() +} + +// recordConnectionMetrics records connection stage timestamps as metrics +func (conn *Conn) recordConnectionMetrics() { + if conn.metricsRecorder == nil { + return + } + + // Determine connection type based on current priority + conn.mu.Lock() + priority := conn.currentConnPriority + conn.mu.Unlock() + + var connType metrics.ConnectionType + switch priority { + case conntype.Relay: + connType = metrics.ConnectionTypeRelay + default: + connType = metrics.ConnectionTypeICE + } + + // Record metrics with timestamps - duration calculation happens in metrics package + conn.metricsRecorder.RecordConnectionStages( + context.Background(), + conn.config.Key, + connType, + conn.metricsStages.IsReconnection(), + conn.metricsStages.GetTimestamps(), + ) +} + // AllowedIP returns the allowed IP of the remote peer func (conn *Conn) AllowedIP() netip.Addr { return conn.config.WgConfig.AllowedIps[0].Addr() diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 32383b530..59216b647 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -15,7 +15,6 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" - semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) var testDispatcher = dispatcher.NewConnectionDispatcher() @@ -53,7 +52,6 @@ func TestConn_GetKey(t *testing.T) { sd := ServiceDependencies{ SrWatcher: swWatcher, - Semaphore: semaphoregroup.NewSemaphoreGroup(1), PeerConnDispatcher: testDispatcher, } conn, err := NewConn(connConf, sd) @@ -71,7 +69,6 @@ func TestConn_OnRemoteOffer(t *testing.T) { sd := ServiceDependencies{ StatusRecorder: NewRecorder("https://mgm"), SrWatcher: swWatcher, - Semaphore: semaphoregroup.NewSemaphoreGroup(1), PeerConnDispatcher: testDispatcher, } conn, err := NewConn(connConf, sd) @@ -110,7 +107,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) { sd := ServiceDependencies{ StatusRecorder: NewRecorder("https://mgm"), SrWatcher: swWatcher, - Semaphore: semaphoregroup.NewSemaphoreGroup(1), PeerConnDispatcher: testDispatcher, } conn, err := NewConn(connConf, sd) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index aff26f847..9b50cecd1 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -44,12 +44,13 @@ type OfferAnswer struct { } type Handshaker struct { - mu sync.Mutex - log *log.Entry - config ConnConfig - signaler *Signaler - ice *WorkerICE - relay *WorkerRelay + mu sync.Mutex + log *log.Entry + config ConnConfig + signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + metricsStages *MetricsStages // relayListener is not blocking because the listener is using a goroutine to process the messages // and it will only keep the latest message if multiple offers are received in a short time // this is to avoid blocking the handshaker if the listener is doing some heavy processing @@ -64,13 +65,14 @@ type Handshaker struct { remoteAnswerCh chan OfferAnswer } -func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker { +func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { return &Handshaker{ log: log, config: config, signaler: signaler, ice: ice, relay: relay, + metricsStages: metricsStages, remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } @@ -89,6 +91,12 @@ func (h *Handshaker) Listen(ctx context.Context) { select { case remoteOfferAnswer := <-h.remoteOffersCh: h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + + // Record signaling received for reconnection attempts + if h.metricsStages != nil { + h.metricsStages.RecordSignalingReceived() + } + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } @@ -103,6 +111,12 @@ func (h *Handshaker) Listen(ctx context.Context) { } case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + + // Record signaling received for reconnection attempts + if h.metricsStages != nil { + h.metricsStages.RecordSignalingReceived() + } + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } diff --git a/client/internal/peer/metrics_saver.go b/client/internal/peer/metrics_saver.go new file mode 100644 index 000000000..e32afbfe5 --- /dev/null +++ b/client/internal/peer/metrics_saver.go @@ -0,0 +1,73 @@ +package peer + +import ( + "sync" + "time" + + "github.com/netbirdio/netbird/client/internal/metrics" +) + +type MetricsStages struct { + isReconnectionAttempt bool // Track if current attempt is a reconnection + stageTimestamps metrics.ConnectionStageTimestamps + mu sync.Mutex +} + +// RecordSignalingReceived records when the first signal is received from the remote peer. +// Used as the base for all subsequent stage durations to avoid inflating metrics when +// the remote peer was offline. +func (s *MetricsStages) RecordSignalingReceived() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.stageTimestamps.SignalingReceived.IsZero() { + s.stageTimestamps.SignalingReceived = time.Now() + } +} + +func (s *MetricsStages) RecordConnectionReady(when time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stageTimestamps.ConnectionReady.IsZero() { + s.stageTimestamps.ConnectionReady = when + } +} + +func (s *MetricsStages) RecordWGHandshakeSuccess(handshakeTime time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.stageTimestamps.ConnectionReady.IsZero() && s.stageTimestamps.WgHandshakeSuccess.IsZero() { + // WireGuard only reports handshake times with second precision, but ConnectionReady + // is captured with microsecond precision. If handshake appears before ConnectionReady + // due to truncation (e.g., handshake at 6.042s truncated to 6.000s), normalize to + // ConnectionReady to avoid negative duration metrics. + if handshakeTime.Before(s.stageTimestamps.ConnectionReady) { + s.stageTimestamps.WgHandshakeSuccess = s.stageTimestamps.ConnectionReady + } else { + s.stageTimestamps.WgHandshakeSuccess = handshakeTime + } + } +} + +// Disconnected sets the mode to reconnection. It is called only when both ICE and Relay have been disconnected at the same time. +func (s *MetricsStages) Disconnected() { + s.mu.Lock() + defer s.mu.Unlock() + + // Reset all timestamps for reconnection + s.stageTimestamps = metrics.ConnectionStageTimestamps{} + s.isReconnectionAttempt = true +} + +func (s *MetricsStages) IsReconnection() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.isReconnectionAttempt +} + +func (s *MetricsStages) GetTimestamps() metrics.ConnectionStageTimestamps { + s.mu.Lock() + defer s.mu.Unlock() + return s.stageTimestamps +} diff --git a/client/internal/peer/metrics_saver_test.go b/client/internal/peer/metrics_saver_test.go new file mode 100644 index 000000000..01c0aa9ac --- /dev/null +++ b/client/internal/peer/metrics_saver_test.go @@ -0,0 +1,125 @@ +package peer + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/metrics" +) + +func TestMetricsStages_RecordSignalingReceived(t *testing.T) { + s := &MetricsStages{} + + s.RecordSignalingReceived() + ts := s.GetTimestamps() + require.False(t, ts.SignalingReceived.IsZero()) + + // Second call should not overwrite + first := ts.SignalingReceived + time.Sleep(time.Millisecond) + s.RecordSignalingReceived() + ts = s.GetTimestamps() + assert.Equal(t, first, ts.SignalingReceived, "should keep the first signaling timestamp") +} + +func TestMetricsStages_RecordConnectionReady(t *testing.T) { + s := &MetricsStages{} + + now := time.Now() + s.RecordConnectionReady(now) + ts := s.GetTimestamps() + assert.Equal(t, now, ts.ConnectionReady) + + // Second call should not overwrite + later := now.Add(time.Second) + s.RecordConnectionReady(later) + ts = s.GetTimestamps() + assert.Equal(t, now, ts.ConnectionReady, "should keep the first connection ready timestamp") +} + +func TestMetricsStages_RecordWGHandshakeSuccess(t *testing.T) { + s := &MetricsStages{} + + connReady := time.Now() + s.RecordConnectionReady(connReady) + + handshake := connReady.Add(500 * time.Millisecond) + s.RecordWGHandshakeSuccess(handshake) + + ts := s.GetTimestamps() + assert.Equal(t, handshake, ts.WgHandshakeSuccess) +} + +func TestMetricsStages_HandshakeBeforeConnectionReady_Normalizes(t *testing.T) { + s := &MetricsStages{} + + connReady := time.Now() + s.RecordConnectionReady(connReady) + + // WG handshake appears before ConnectionReady due to second-precision truncation + handshake := connReady.Add(-100 * time.Millisecond) + s.RecordWGHandshakeSuccess(handshake) + + ts := s.GetTimestamps() + assert.Equal(t, connReady, ts.WgHandshakeSuccess, "should normalize to ConnectionReady when handshake appears earlier") +} + +func TestMetricsStages_HandshakeIgnoredWithoutConnectionReady(t *testing.T) { + s := &MetricsStages{} + + s.RecordWGHandshakeSuccess(time.Now()) + ts := s.GetTimestamps() + assert.True(t, ts.WgHandshakeSuccess.IsZero(), "should not record handshake without connection ready") +} + +func TestMetricsStages_HandshakeRecordedOnce(t *testing.T) { + s := &MetricsStages{} + + connReady := time.Now() + s.RecordConnectionReady(connReady) + + first := connReady.Add(time.Second) + s.RecordWGHandshakeSuccess(first) + + // Second call (rekey) should be ignored + second := connReady.Add(2 * time.Second) + s.RecordWGHandshakeSuccess(second) + + ts := s.GetTimestamps() + assert.Equal(t, first, ts.WgHandshakeSuccess, "should preserve first handshake, ignore rekeys") +} + +func TestMetricsStages_Disconnected(t *testing.T) { + s := &MetricsStages{} + + s.RecordSignalingReceived() + s.RecordConnectionReady(time.Now()) + assert.False(t, s.IsReconnection()) + + s.Disconnected() + + assert.True(t, s.IsReconnection()) + ts := s.GetTimestamps() + assert.True(t, ts.SignalingReceived.IsZero(), "timestamps should be reset after disconnect") + assert.True(t, ts.ConnectionReady.IsZero(), "timestamps should be reset after disconnect") + assert.True(t, ts.WgHandshakeSuccess.IsZero(), "timestamps should be reset after disconnect") +} + +func TestMetricsStages_GetTimestamps(t *testing.T) { + s := &MetricsStages{} + + ts := s.GetTimestamps() + assert.Equal(t, metrics.ConnectionStageTimestamps{}, ts) + + now := time.Now() + s.RecordSignalingReceived() + s.RecordConnectionReady(now) + + ts = s.GetTimestamps() + assert.False(t, ts.SignalingReceived.IsZero()) + assert.Equal(t, now, ts.ConnectionReady) + assert.True(t, ts.WgHandshakeSuccess.IsZero()) +} diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 799a9375e..805a6f24a 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -48,7 +48,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin // EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. // The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management. -func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) { +func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) { w.muEnabled.Lock() if w.enabled { w.muEnabled.Unlock() @@ -56,7 +56,6 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func() } w.log.Debugf("enable WireGuard watcher") - enabledTime := time.Now() w.enabled = true w.muEnabled.Unlock() @@ -65,7 +64,7 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func() w.log.Warnf("failed to read initial wg stats: %v", err) } - w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake) + w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake) w.muEnabled.Lock() w.enabled = false @@ -89,7 +88,7 @@ func (w *WGWatcher) Reset() { } // wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) { +func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time), enabledTime time.Time, initialHandshake time.Time) { w.log.Infof("WireGuard watcher started") timer := time.NewTimer(wgHandshakeOvertime) @@ -108,6 +107,9 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn if lastHandshake.IsZero() { elapsed := calcElapsed(enabledTime, *handshake) w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + if onHandshakeSuccessFn != nil { + onHandshakeSuccessFn(*handshake) + } } lastHandshake = *handshake diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go index f79405a01..3ce91cd46 100644 --- a/client/internal/peer/wg_watcher_test.go +++ b/client/internal/peer/wg_watcher_test.go @@ -35,9 +35,11 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) { defer cancel() onDisconnected := make(chan struct{}, 1) - go watcher.EnableWgWatcher(ctx, func() { + go watcher.EnableWgWatcher(ctx, time.Now(), func() { mlog.Infof("onDisconnectedFn") onDisconnected <- struct{}{} + }, func(when time.Time) { + mlog.Infof("onHandshakeSuccess: %v", when) }) // wait for initial reading @@ -64,7 +66,7 @@ func TestWGWatcher_ReEnable(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - watcher.EnableWgWatcher(ctx, func() {}) + watcher.EnableWgWatcher(ctx, time.Now(), func() {}, func(when time.Time) {}) }() cancel() @@ -75,9 +77,9 @@ func TestWGWatcher_ReEnable(t *testing.T) { defer cancel() onDisconnected := make(chan struct{}, 1) - go watcher.EnableWgWatcher(ctx, func() { + go watcher.EnableWgWatcher(ctx, time.Now(), func() { onDisconnected <- struct{}{} - }) + }, func(when time.Time) {}) time.Sleep(2 * time.Second) mocWgIface.disconnect() diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index edd70fb20..29bf5aaaa 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) @@ -61,6 +62,9 @@ type WorkerICE struct { // we record the last known state of the ICE agent to avoid duplicate on disconnected events lastKnownState ice.ConnectionState + + // portForwardAttempted tracks if we've already tried port forwarding this session + portForwardAttempted bool } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { @@ -214,6 +218,8 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { + w.portForwardAttempted = false + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) @@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() + + if candidate.Type() == ice.CandidateTypeServerReflexive { + w.injectPortForwardedCandidate(candidate) + } +} + +// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping. +func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { + pfManager := w.conn.portForwardManager + if pfManager == nil { + return + } + + mapping := pfManager.GetMapping() + if mapping == nil { + return + } + + w.muxAgent.Lock() + if w.portForwardAttempted { + w.muxAgent.Unlock() + return + } + w.portForwardAttempted = true + w.muxAgent.Unlock() + + forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping) + if err != nil { + w.log.Warnf("create forwarded candidate: %v", err) + return + } + + w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)", + forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority()) + + go func() { + if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil { + w.log.Errorf("signal port-forwarded candidate: %v", err) + } + }() +} + +// createForwardedCandidate creates a new server reflexive candidate with the forwarded port. +// It uses the NAT gateway's external IP with the forwarded port. +func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) { + var externalIP string + if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() { + externalIP = mapping.ExternalIP.String() + } else { + // Fallback to STUN-discovered address if NAT didn't provide external IP + externalIP = srflxCandidate.Address() + } + + // Per RFC 8445, the related address for srflx is the base (host candidate address). + // If the original srflx has unspecified related address, use its own address as base. + relAddr := srflxCandidate.RelatedAddress().Address + if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" { + relAddr = srflxCandidate.Address() + } + + // Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates + // over regular srflx during ICE connectivity checks. + priority := srflxCandidate.Priority() + 1000 + + candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + Network: srflxCandidate.NetworkType().String(), + Address: externalIP, + Port: int(mapping.ExternalPort), + Component: srflxCandidate.Component(), + Priority: priority, + RelAddr: relAddr, + RelPort: int(mapping.InternalPort), + }) + if err != nil { + return nil, fmt.Errorf("create candidate: %w", err) + } + + for _, e := range srflxCandidate.Extensions() { + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = srflxCandidate.ID() + } + if err := candidate.AddExtension(e); err != nil { + return nil, fmt.Errorf("add extension: %w", err) + } + } + + return candidate, nil } func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { @@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { if !lok || !rok { continue } - w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms", sessionID, - local.NetworkType(), local.Type(), local.Address(), - remote.NetworkType(), remote.Type(), remote.Address(), + local.NetworkType(), local.Type(), local.Address(), local.Port(), + remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(), stat.CurrentRoundTripTime*1000) } } diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go new file mode 100644 index 000000000..ba83c79bf --- /dev/null +++ b/client/internal/portforward/env.go @@ -0,0 +1,35 @@ +package portforward + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" + envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK" +) + +func isDisabledByEnv() bool { + return parseBoolEnv(envDisableNATMapper) +} + +func isHealthCheckDisabled() bool { + return parseBoolEnv(envDisablePCPHealthCheck) +} + +func parseBoolEnv(key string) bool { + val := os.Getenv(key) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", key, err) + return false + } + return disabled +} diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go new file mode 100644 index 000000000..b0680160c --- /dev/null +++ b/client/internal/portforward/manager.go @@ -0,0 +1,342 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + "net" + "regexp" + "sync" + "time" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" +) + +const ( + defaultMappingTTL = 2 * time.Hour + healthCheckInterval = 1 * time.Minute + discoveryTimeout = 10 * time.Second + mappingDescription = "NetBird" +) + +// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML, +// allowing for whitespace/newlines between tags from different router firmware. +var upnpErrPermanentLeaseOnly = regexp.MustCompile(`\s*725\s*`) + +// Mapping represents an active NAT port mapping. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string + // TTL is the lease duration. Zero means a permanent lease that never expires. + TTL time.Duration +} + +// TODO: persist mapping state for crash recovery cleanup of permanent leases. +// Currently not done because State.Cleanup requires NAT gateway re-discovery, +// which blocks startup for ~10s when no gateway is present (affects all clients). + +type Manager struct { + cancel context.CancelFunc + + mapping *Mapping + mappingLock sync.Mutex + + wgPort uint16 + + done chan struct{} + stopCtx chan context.Context + + // protect exported functions + mu sync.Mutex +} + +// NewManager creates a new port forwarding manager. +func NewManager() *Manager { + return &Manager{ + stopCtx: make(chan context.Context, 1), + } +} + +func (m *Manager) Start(ctx context.Context, wgPort uint16) { + m.mu.Lock() + if m.cancel != nil { + m.mu.Unlock() + return + } + + if isDisabledByEnv() { + log.Infof("NAT port mapper disabled via %s", envDisableNATMapper) + m.mu.Unlock() + return + } + + if wgPort == 0 { + log.Warnf("invalid WireGuard port 0; NAT mapping disabled") + m.mu.Unlock() + return + } + m.wgPort = wgPort + + m.done = make(chan struct{}) + defer close(m.done) + + ctx, m.cancel = context.WithCancel(ctx) + m.mu.Unlock() + + gateway, mapping, err := m.setup(ctx) + if err != nil { + log.Infof("port forwarding setup: %v", err) + return + } + + m.mappingLock.Lock() + m.mapping = mapping + m.mappingLock.Unlock() + + m.renewLoop(ctx, gateway, mapping.TTL) + + select { + case cleanupCtx := <-m.stopCtx: + // block the Start while cleaned up gracefully + m.cleanup(cleanupCtx, gateway) + default: + // return Start immediately and cleanup in background + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second) + go func() { + defer cleanupCancel() + m.cleanup(cleanupCtx, gateway) + }() + } +} + +// GetMapping returns the current mapping if ready, nil otherwise +func (m *Manager) GetMapping() *Mapping { + m.mappingLock.Lock() + defer m.mappingLock.Unlock() + + if m.mapping == nil { + return nil + } + + mapping := *m.mapping + return &mapping +} + +// GracefullyStop cancels the manager and attempts to delete the port mapping. +// After GracefullyStop returns, the manager cannot be restarted. +func (m *Manager) GracefullyStop(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel == nil { + return nil + } + + // Send cleanup context before cancelling, so Start picks it up after renewLoop exits. + m.startTearDown(ctx) + + m.cancel() + m.cancel = nil + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } +} + +func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { + discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) + defer discoverCancel() + + gateway, err := discoverGateway(discoverCtx) + if err != nil { + return nil, nil, fmt.Errorf("discover gateway: %w", err) + } + + log.Infof("discovered NAT gateway: %s", gateway.Type()) + + mapping, err := m.createMapping(ctx, gateway) + if err != nil { + return nil, nil, fmt.Errorf("create port mapping: %w", err) + } + return gateway, mapping, nil +} + +func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ttl := defaultMappingTTL + externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl) + if err != nil { + if !isPermanentLeaseRequired(err) { + return nil, err + } + log.Infof("gateway only supports permanent leases, retrying with indefinite duration") + ttl = 0 + externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl) + if err != nil { + return nil, err + } + } + + externalIP, err := gateway.GetExternalAddress() + if err != nil { + log.Debugf("failed to get external address: %v", err) + } + + mapping := &Mapping{ + Protocol: "udp", + InternalPort: m.wgPort, + ExternalPort: uint16(externalPort), + ExternalIP: externalIP, + NATType: gateway.Type(), + TTL: ttl, + } + + log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", + m.wgPort, externalPort, gateway.Type(), externalIP) + return mapping, nil +} + +func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) { + if ttl == 0 { + // Permanent mappings don't expire, just wait for cancellation + // but still run health checks for PCP gateways. + m.permanentLeaseLoop(ctx, gateway) + return + } + + renewTicker := time.NewTicker(ttl / 2) + healthTicker := time.NewTicker(healthCheckInterval) + defer renewTicker.Stop() + defer healthTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-renewTicker.C: + if err := m.renewMapping(ctx, gateway); err != nil { + log.Warnf("failed to renew port mapping: %v", err) + continue + } + case <-healthTicker.C: + if m.checkHealthAndRecreate(ctx, gateway) { + renewTicker.Reset(ttl / 2) + } + } + } +} + +func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) { + healthTicker := time.NewTicker(healthCheckInterval) + defer healthTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-healthTicker.C: + m.checkHealthAndRecreate(ctx, gateway) + } + } +} + +func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool { + if isHealthCheckDisabled() { + return false + } + + m.mappingLock.Lock() + hasMapping := m.mapping != nil + m.mappingLock.Unlock() + + if !hasMapping { + return false + } + + pcpNAT, ok := gateway.(*pcp.NAT) + if !ok { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx) + if err != nil { + log.Debugf("PCP health check failed: %v", err) + return false + } + + if serverRestarted { + log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch) + if err := m.renewMapping(ctx, gateway); err != nil { + log.Errorf("failed to recreate port mapping after server restart: %v", err) + return false + } + return true + } + + return false +} + +func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL) + if err != nil { + return fmt.Errorf("add port mapping: %w", err) + } + + if uint16(externalPort) != m.mapping.ExternalPort { + log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort) + m.mappingLock.Lock() + m.mapping.ExternalPort = uint16(externalPort) + m.mappingLock.Unlock() + } + + log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort) + return nil +} + +func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) { + m.mappingLock.Lock() + mapping := m.mapping + m.mapping = nil + m.mappingLock.Unlock() + + if mapping == nil { + return + } + + if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { + log.Warnf("delete port mapping on stop: %v", err) + return + } + + log.Infof("deleted port mapping for port %d", mapping.InternalPort) +} + +func (m *Manager) startTearDown(ctx context.Context) { + select { + case m.stopCtx <- ctx: + default: + } +} + +// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725). +func isPermanentLeaseRequired(err error) bool { + return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error()) +} diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go new file mode 100644 index 000000000..36c55063b --- /dev/null +++ b/client/internal/portforward/manager_js.go @@ -0,0 +1,39 @@ +package portforward + +import ( + "context" + "net" + "time" +) + +// Mapping represents an active NAT port mapping. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string + // TTL is the lease duration. Zero means a permanent lease that never expires. + TTL time.Duration +} + +// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported. +type Manager struct{} + +// NewManager returns a stub manager for js/wasm builds. +func NewManager() *Manager { + return &Manager{} +} + +// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments. +func (m *Manager) Start(context.Context, uint16) { + // no NAT traversal in wasm +} + +// GracefullyStop is a no-op on js/wasm. +func (m *Manager) GracefullyStop(context.Context) error { return nil } + +// GetMapping always returns nil on js/wasm. +func (m *Manager) GetMapping() *Mapping { + return nil +} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go new file mode 100644 index 000000000..1f66f9ccd --- /dev/null +++ b/client/internal/portforward/manager_test.go @@ -0,0 +1,201 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockNAT struct { + natType string + deviceAddr net.IP + externalAddr net.IP + internalAddr net.IP + mappings map[int]int + addMappingErr error + deleteMappingErr error + onlyPermanentLeases bool + lastTimeout time.Duration +} + +func newMockNAT() *mockNAT { + return &mockNAT{ + natType: "Mock-NAT", + deviceAddr: net.ParseIP("192.168.1.1"), + externalAddr: net.ParseIP("203.0.113.50"), + internalAddr: net.ParseIP("192.168.1.100"), + mappings: make(map[int]int), + } +} + +func (m *mockNAT) Type() string { + return m.natType +} + +func (m *mockNAT) GetDeviceAddress() (net.IP, error) { + return m.deviceAddr, nil +} + +func (m *mockNAT) GetExternalAddress() (net.IP, error) { + return m.externalAddr, nil +} + +func (m *mockNAT) GetInternalAddress() (net.IP, error) { + return m.internalAddr, nil +} + +func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) { + if m.addMappingErr != nil { + return 0, m.addMappingErr + } + if m.onlyPermanentLeases && timeout != 0 { + return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: 725OnlyPermanentLeasesSupported") + } + externalPort := internalPort + m.mappings[internalPort] = externalPort + m.lastTimeout = timeout + return externalPort, nil +} + +func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if m.deleteMappingErr != nil { + return m.deleteMappingErr + } + delete(m.mappings, internalPort) + return nil +} + +func TestManager_CreateMapping(t *testing.T) { + m := NewManager() + m.wgPort = 51820 + + gateway := newMockNAT() + mapping, err := m.createMapping(context.Background(), gateway) + require.NoError(t, err) + require.NotNil(t, mapping) + + assert.Equal(t, "udp", mapping.Protocol) + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, uint16(51820), mapping.ExternalPort) + assert.Equal(t, "Mock-NAT", mapping.NATType) + assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4()) + assert.Equal(t, defaultMappingTTL, mapping.TTL) +} + +func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) { + m := NewManager() + assert.Nil(t, m.GetMapping()) +} + +func TestManager_GetMapping_ReturnsCopy(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } + + mapping := m.GetMapping() + require.NotNil(t, mapping) + assert.Equal(t, uint16(51820), mapping.InternalPort) + + // Mutating the returned copy should not affect the manager's mapping. + mapping.ExternalPort = 9999 + assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort) +} + +func TestManager_Cleanup_DeletesMapping(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } + + gateway := newMockNAT() + // Seed the mock so we can verify deletion. + gateway.mappings[51820] = 51820 + + m.cleanup(context.Background(), gateway) + + _, exists := gateway.mappings[51820] + assert.False(t, exists, "mapping should be deleted from gateway") + assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared") +} + +func TestManager_Cleanup_NilMapping(t *testing.T) { + m := NewManager() + gateway := newMockNAT() + + // Should not panic or call gateway. + m.cleanup(context.Background(), gateway) +} + + +func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) { + m := NewManager() + m.wgPort = 51820 + + gateway := newMockNAT() + gateway.onlyPermanentLeases = true + + mapping, err := m.createMapping(context.Background(), gateway) + require.NoError(t, err) + require.NotNil(t, mapping) + + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease") + assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration") +} + +func TestIsPermanentLeaseRequired(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "UPnP error 725", + err: fmt.Errorf("SOAP fault. Code: | Detail: 725OnlyPermanentLeasesSupported"), + expected: true, + }, + { + name: "wrapped error with 725", + err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: 725")), + expected: true, + }, + { + name: "error 725 with newlines in XML", + err: fmt.Errorf("\n 725\n"), + expected: true, + }, + { + name: "bare 725 without XML tag", + err: fmt.Errorf("error code 725"), + expected: false, + }, + { + name: "unrelated error", + err: fmt.Errorf("connection refused"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err)) + }) + } +} diff --git a/client/internal/portforward/pcp/client.go b/client/internal/portforward/pcp/client.go new file mode 100644 index 000000000..f6d243ef9 --- /dev/null +++ b/client/internal/portforward/pcp/client.go @@ -0,0 +1,408 @@ +package pcp + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultTimeout = 3 * time.Second + responseBufferSize = 128 + + // RFC 6887 Section 8.1.1 retry timing + initialRetryDelay = 3 * time.Second + maxRetryDelay = 1024 * time.Second + maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case +) + +// Client is a PCP protocol client. +// All methods are safe for concurrent use. +type Client struct { + gateway netip.Addr + timeout time.Duration + + mu sync.Mutex + // localIP caches the resolved local IP address. + localIP netip.Addr + // lastEpoch is the last observed server epoch value. + lastEpoch uint32 + // epochTime tracks when lastEpoch was received for state loss detection. + epochTime time.Time + // externalIP caches the external IP from the last successful MAP response. + externalIP netip.Addr + // epochStateLost is set when epoch indicates server restart. + epochStateLost bool +} + +// NewClient creates a new PCP client for the gateway at the given IP. +func NewClient(gateway net.IP) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: defaultTimeout, + } +} + +// NewClientWithTimeout creates a new PCP client with a custom timeout. +func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: timeout, + } +} + +// SetLocalIP sets the local IP address to use in PCP requests. +func (c *Client) SetLocalIP(ip net.IP) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + log.Debugf("invalid local IP: %v", ip) + } + c.mu.Lock() + c.localIP = addr.Unmap() + c.mu.Unlock() +} + +// Gateway returns the gateway IP address. +func (c *Client) Gateway() net.IP { + return c.gateway.AsSlice() +} + +// Announce sends a PCP ANNOUNCE request to discover PCP support. +// Returns the server's epoch time on success. +func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) { + localIP, err := c.getLocalIP() + if err != nil { + return 0, fmt.Errorf("get local IP: %w", err) + } + + req := buildAnnounceRequest(localIP) + resp, err := c.sendRequest(ctx, req) + if err != nil { + return 0, fmt.Errorf("send announce: %w", err) + } + + parsed, err := parseResponse(resp) + if err != nil { + return 0, fmt.Errorf("parse announce response: %w", err) + } + + if parsed.ResultCode != ResultSuccess { + return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode)) + } + + c.mu.Lock() + if c.updateEpochLocked(parsed.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.mu.Unlock() + return parsed.Epoch, nil +} + +// AddPortMapping requests a port mapping from the PCP server. +func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) { + return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime) +} + +// AddPortMappingWithHint requests a port mapping with suggested external port and IP. +// Use lifetime <= 0 to delete a mapping. +func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) { + var extIP netip.Addr + if suggestedExtIP != nil { + var ok bool + extIP, ok = netip.AddrFromSlice(suggestedExtIP) + if !ok { + log.Debugf("invalid suggested external IP: %v", suggestedExtIP) + } + extIP = extIP.Unmap() + } + return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime) +} + +func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) { + localIP, err := c.getLocalIP() + if err != nil { + return nil, fmt.Errorf("get local IP: %w", err) + } + + proto, err := protocolNumber(protocol) + if err != nil { + return nil, fmt.Errorf("parse protocol: %w", err) + } + + var nonce [12]byte + if _, err := rand.Read(nonce[:]); err != nil { + return nil, fmt.Errorf("generate nonce: %w", err) + } + + // Convert lifetime to seconds. Lifetime 0 means delete, so only apply + // default for positive durations that round to 0 seconds. + var lifetimeSec uint32 + if lifetime > 0 { + lifetimeSec = uint32(lifetime.Seconds()) + if lifetimeSec == 0 { + lifetimeSec = DefaultLifetime + } + } + + req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec) + + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("send map request: %w", err) + } + + mapResp, err := parseMapResponse(resp) + if err != nil { + return nil, fmt.Errorf("parse map response: %w", err) + } + + if mapResp.Nonce != nonce { + return nil, fmt.Errorf("nonce mismatch in response") + } + + if mapResp.Protocol != proto { + return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol) + } + if mapResp.InternalPort != uint16(internalPort) { + return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort) + } + + if mapResp.ResultCode != ResultSuccess { + return nil, &Error{ + Code: mapResp.ResultCode, + Message: ResultCodeString(mapResp.ResultCode), + } + } + + c.mu.Lock() + if c.updateEpochLocked(mapResp.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.cacheExternalIPLocked(mapResp.ExternalIP) + c.mu.Unlock() + return mapResp, nil +} + +// DeletePortMapping removes a port mapping by requesting zero lifetime. +func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil { + var pcpErr *Error + if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized { + return nil + } + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// GetExternalAddress returns the external IP address. +// First checks for a cached value from previous MAP responses. +// If not cached, creates a short-lived mapping to discover the external IP. +func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) { + c.mu.Lock() + if c.externalIP.IsValid() { + ip := c.externalIP.AsSlice() + c.mu.Unlock() + return ip, nil + } + c.mu.Unlock() + + // Use an ephemeral port in the dynamic range (49152-65535). + // Port 0 is not valid with UDP/TCP protocols per RFC 6887. + ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152) + + // Use minimal lifetime (1 second) for discovery. + resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second) + if err != nil { + return nil, fmt.Errorf("create temporary mapping: %w", err) + } + + if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil { + log.Debugf("cleanup temporary PCP mapping: %v", err) + } + + return resp.ExternalIP.AsSlice(), nil +} + +// LastEpoch returns the last observed server epoch value. +// A decrease in epoch indicates the server may have restarted and mappings may be lost. +func (c *Client) LastEpoch() uint32 { + c.mu.Lock() + defer c.mu.Unlock() + return c.lastEpoch +} + +// EpochStateLost returns true if epoch state loss was detected and clears the flag. +func (c *Client) EpochStateLost() bool { + c.mu.Lock() + defer c.mu.Unlock() + lost := c.epochStateLost + c.epochStateLost = false + return lost +} + +// updateEpoch updates the epoch tracking and detects potential state loss. +// Returns true if state loss was detected (server likely restarted). +// Caller must hold c.mu. +func (c *Client) updateEpochLocked(newEpoch uint32) bool { + now := time.Now() + stateLost := false + + // RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss. + // client_delta = time since last response + // server_delta = epoch change since last response + // Invalid if: client_delta+2 < server_delta - server_delta/16 + // OR: server_delta+2 < client_delta - client_delta/16 + // The +2 handles quantization, /16 (6.25%) handles clock drift. + if !c.epochTime.IsZero() && c.lastEpoch > 0 { + clientDelta := uint32(now.Sub(c.epochTime).Seconds()) + serverDelta := newEpoch - c.lastEpoch + + // Check for epoch going backwards or jumping unexpectedly. + // Subtraction is safe: serverDelta/16 is always <= serverDelta. + if clientDelta+2 < serverDelta-(serverDelta/16) || + serverDelta+2 < clientDelta-(clientDelta/16) { + stateLost = true + c.epochStateLost = true + } + } + + c.lastEpoch = newEpoch + c.epochTime = now + return stateLost +} + +// cacheExternalIP stores the external IP from a successful MAP response. +// Caller must hold c.mu. +func (c *Client) cacheExternalIPLocked(ip netip.Addr) { + if ip.IsValid() && !ip.IsUnspecified() { + c.externalIP = ip + } +} + +// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1. +func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) { + addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port} + + var lastErr error + delay := initialRetryDelay + + for range maxRetries { + resp, err := c.sendOnce(ctx, addr, req) + if err == nil { + return resp, nil + } + lastErr = err + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT) + // RAND is random between -0.1 and +0.1 + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelayWithJitter(delay)): + } + delay = min(delay*2, maxRetryDelay) + } + + return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr) +} + +// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1]. +func retryDelayWithJitter(d time.Duration) time.Duration { + var b [1]byte + _, _ = rand.Read(b[:]) + // Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1 + jitter := (float64(b[0])/255.0)*0.2 - 0.1 + return time.Duration(float64(d) * (1 + jitter)) +} + +func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) { + // Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3. + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, fmt.Errorf("listen: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("close UDP connection: %v", err) + } + }() + + timeout := c.timeout + if deadline, ok := ctx.Deadline(); ok { + if remaining := time.Until(deadline); remaining < timeout { + timeout = remaining + } + } + + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("set deadline: %w", err) + } + + if _, err := conn.WriteToUDP(req, addr); err != nil { + return nil, fmt.Errorf("write: %w", err) + } + + resp := make([]byte, responseBufferSize) + n, from, err := conn.ReadFromUDP(resp) + if err != nil { + return nil, fmt.Errorf("read: %w", err) + } + + // RFC 6887 §8.3: Validate response came from expected PCP server. + if !from.IP.Equal(addr.IP) { + return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP) + } + + return resp[:n], nil +} + +func (c *Client) getLocalIP() (netip.Addr, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.localIP.IsValid() { + return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway) + } + return c.localIP, nil +} + +func protocolNumber(protocol string) (uint8, error) { + switch protocol { + case "udp", "UDP": + return ProtoUDP, nil + case "tcp", "TCP": + return ProtoTCP, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} + +// Error represents a PCP error response. +type Error struct { + Code uint8 + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code) +} diff --git a/client/internal/portforward/pcp/client_test.go b/client/internal/portforward/pcp/client_test.go new file mode 100644 index 000000000..79f44a426 --- /dev/null +++ b/client/internal/portforward/pcp/client_test.go @@ -0,0 +1,187 @@ +package pcp + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddrConversion(t *testing.T) { + tests := []struct { + name string + addr netip.Addr + }{ + {"IPv4", netip.MustParseAddr("192.168.1.100")}, + {"IPv4 loopback", netip.MustParseAddr("127.0.0.1")}, + {"IPv6", netip.MustParseAddr("2001:db8::1")}, + {"IPv6 loopback", netip.MustParseAddr("::1")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b16 := addrTo16(tt.addr) + + recovered := addrFrom16(b16) + assert.Equal(t, tt.addr, recovered, "address should round-trip") + }) + } +} + +func TestBuildAnnounceRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + req := buildAnnounceRequest(clientIP) + + require.Len(t, req, headerSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpAnnounce), req[1], "opcode") + + // Check client IP is properly encoded as IPv4-mapped IPv6 + assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10") + assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11") + assert.Equal(t, byte(192), req[20], "IP octet 1") + assert.Equal(t, byte(168), req[21], "IP octet 2") + assert.Equal(t, byte(1), req[22], "IP octet 3") + assert.Equal(t, byte(100), req[23], "IP octet 4") +} + +func TestBuildMapRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600) + + require.Len(t, req, mapRequestSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpMap), req[1], "opcode") + + // Lifetime at bytes 4-7 + assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime") + + // Nonce at bytes 24-35 + assert.Equal(t, nonce[:], req[24:36], "nonce") + + // Protocol at byte 36 + assert.Equal(t, byte(ProtoUDP), req[36], "protocol") + + // Internal port at bytes 40-41 + assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port") + + // External port at bytes 42-43 + assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port") +} + +func TestParseResponse(t *testing.T) { + // Construct a valid ANNOUNCE response + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce | OpReply + // Result code = 0 (success) + // Lifetime = 0 + // Epoch = 12345 + resp[8] = 0 + resp[9] = 0 + resp[10] = 0x30 + resp[11] = 0x39 + + parsed, err := parseResponse(resp) + require.NoError(t, err) + assert.Equal(t, uint8(Version), parsed.Version) + assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode) + assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode) + assert.Equal(t, uint32(12345), parsed.Epoch) +} + +func TestParseResponseErrors(t *testing.T) { + t.Run("too short", func(t *testing.T) { + _, err := parseResponse([]byte{1, 2, 3}) + assert.Error(t, err) + }) + + t.Run("wrong version", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = 1 // Wrong version + resp[1] = OpReply + _, err := parseResponse(resp) + assert.Error(t, err) + }) + + t.Run("missing reply bit", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce // Missing OpReply bit + _, err := parseResponse(resp) + assert.Error(t, err) + }) +} + +func TestResultCodeString(t *testing.T) { + assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess)) + assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized)) + assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch)) + assert.Contains(t, ResultCodeString(255), "UNKNOWN") +} + +func TestProtocolNumber(t *testing.T) { + proto, err := protocolNumber("udp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + proto, err = protocolNumber("tcp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoTCP), proto) + + proto, err = protocolNumber("UDP") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + _, err = protocolNumber("icmp") + assert.Error(t, err) +} + +func TestClientCreation(t *testing.T) { + gateway := netip.MustParseAddr("192.168.1.1").AsSlice() + + client := NewClient(gateway) + assert.Equal(t, net.IP(gateway), client.Gateway()) + assert.Equal(t, defaultTimeout, client.timeout) + + clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second) + assert.Equal(t, 5*time.Second, clientWithTimeout.timeout) +} + +func TestNATType(t *testing.T) { + n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice()) + assert.Equal(t, "PCP", n.Type()) +} + +// Integration test - skipped unless PCP_TEST_GATEWAY env is set +func TestClientIntegration(t *testing.T) { + t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=") + + gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway + localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP + + client := NewClient(gateway) + client.SetLocalIP(localIP) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Test ANNOUNCE + epoch, err := client.Announce(ctx) + require.NoError(t, err) + t.Logf("Server epoch: %d", epoch) + + // Test MAP + resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour) + require.NoError(t, err) + t.Logf("Mapping: internal=%d external=%d externalIP=%s", + resp.InternalPort, resp.ExternalPort, resp.ExternalIP) + + // Cleanup + err = client.DeletePortMapping(ctx, "udp", 51820) + require.NoError(t, err) +} diff --git a/client/internal/portforward/pcp/nat.go b/client/internal/portforward/pcp/nat.go new file mode 100644 index 000000000..1dc24274b --- /dev/null +++ b/client/internal/portforward/pcp/nat.go @@ -0,0 +1,209 @@ +package pcp + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/libp2p/go-nat" + "github.com/libp2p/go-netroute" +) + +var _ nat.NAT = (*NAT)(nil) + +// NAT implements the go-nat NAT interface using PCP. +// Supports dual-stack (IPv4 and IPv6) when available. +// All methods are safe for concurrent use. +// +// TODO: IPv6 pinholes use the local IPv6 address. If the address changes +// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale +// and needs to be recreated with the new address. +type NAT struct { + client *Client + + mu sync.RWMutex + // client6 is the IPv6 PCP client, nil if IPv6 is unavailable. + client6 *Client + // localIP6 caches the local IPv6 address used for PCP requests. + localIP6 netip.Addr +} + +// NewNAT creates a new NAT instance backed by PCP. +func NewNAT(gateway, localIP net.IP) *NAT { + client := NewClient(gateway) + client.SetLocalIP(localIP) + return &NAT{ + client: client, + } +} + +// Type returns "PCP" as the NAT type. +func (n *NAT) Type() string { + return "PCP" +} + +// GetDeviceAddress returns the gateway IP address. +func (n *NAT) GetDeviceAddress() (net.IP, error) { + return n.client.Gateway(), nil +} + +// GetExternalAddress returns the external IP address. +func (n *NAT) GetExternalAddress() (net.IP, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return n.client.GetExternalAddress(ctx) +} + +// GetInternalAddress returns the local IP address used to communicate with the gateway. +func (n *NAT) GetInternalAddress() (net.IP, error) { + addr, err := n.client.getLocalIP() + if err != nil { + return nil, err + } + return addr.AsSlice(), nil +} + +// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available). +func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) { + resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout) + if err != nil { + return 0, fmt.Errorf("add mapping: %w", err) + } + + n.mu.RLock() + client6 := n.client6 + localIP6 := n.localIP6 + n.mu.RUnlock() + + if client6 == nil { + return int(resp.ExternalPort), nil + } + + if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil { + log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err) + return int(resp.ExternalPort), nil + } + + log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort) + return int(resp.ExternalPort), nil +} + +// DeletePortMapping removes a port mapping from both IPv4 and IPv6. +func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + err := n.client.DeletePortMapping(ctx, protocol, internalPort) + + n.mu.RLock() + client6 := n.client6 + n.mu.RUnlock() + + if client6 != nil { + if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil { + log.Warnf("IPv6 PCP delete mapping failed: %v", err6) + } + } + + if err != nil { + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive. +// Returns the current epoch and whether the server may have restarted (epoch state loss detected). +func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) { + epoch, err = n.client.Announce(ctx) + if err != nil { + return 0, false, fmt.Errorf("announce: %w", err) + } + return epoch, n.client.EpochStateLost(), nil +} + +// DiscoverPCP attempts to discover a PCP-capable gateway. +// Returns a NAT interface if PCP is supported, or an error otherwise. +// Discovers both IPv4 and IPv6 gateways when available. +func DiscoverPCP(ctx context.Context) (nat.NAT, error) { + gateway, localIP, err := getDefaultGateway() + if err != nil { + return nil, fmt.Errorf("get default gateway: %w", err) + } + + client := NewClient(gateway) + client.SetLocalIP(localIP) + if _, err := client.Announce(ctx); err != nil { + return nil, fmt.Errorf("PCP announce: %w", err) + } + + result := &NAT{client: client} + discoverIPv6(ctx, result) + + return result, nil +} + +func discoverIPv6(ctx context.Context, result *NAT) { + gateway6, localIP6, err := getDefaultGateway6() + if err != nil { + log.Debugf("IPv6 gateway discovery failed: %v", err) + return + } + + client6 := NewClient(gateway6) + client6.SetLocalIP(localIP6) + if _, err := client6.Announce(ctx); err != nil { + log.Debugf("PCP IPv6 announce failed: %v", err) + return + } + + addr, ok := netip.AddrFromSlice(localIP6) + if !ok { + log.Debugf("invalid IPv6 local IP: %v", localIP6) + return + } + result.mu.Lock() + result.client6 = client6 + result.localIP6 = addr + result.mu.Unlock() + log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6) +} + +// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table. +func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv4zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} + +// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table. +func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv6zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} diff --git a/client/internal/portforward/pcp/protocol.go b/client/internal/portforward/pcp/protocol.go new file mode 100644 index 000000000..d81c50c8c --- /dev/null +++ b/client/internal/portforward/pcp/protocol.go @@ -0,0 +1,225 @@ +// Package pcp implements the Port Control Protocol (RFC 6887). +// +// # Implemented Features +// +// - ANNOUNCE opcode: Discovers PCP server support +// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6) +// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients +// - Nonce validation: Prevents response spoofing +// - Epoch tracking: Detects server restarts per Section 8.5 +// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1) +// +// # Not Implemented +// +// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal) +// - THIRD_PARTY option: For managing mappings on behalf of other devices +// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing) +// - FILTER option: To restrict remote peer addresses +// +// These optional features are omitted because the primary use case is simple +// port forwarding for WireGuard, which only requires MAP with default behavior. +package pcp + +import ( + "encoding/binary" + "fmt" + "net/netip" +) + +const ( + // Version is the PCP protocol version (RFC 6887). + Version = 2 + + // Port is the standard PCP server port. + Port = 5351 + + // DefaultLifetime is the default requested mapping lifetime in seconds. + DefaultLifetime = 7200 // 2 hours + + // Header sizes + headerSize = 24 + mapPayloadSize = 36 + mapRequestSize = headerSize + mapPayloadSize // 60 bytes +) + +// Opcodes +const ( + OpAnnounce = 0 + OpMap = 1 + OpPeer = 2 + OpReply = 0x80 // OR'd with opcode in responses +) + +// Protocol numbers for MAP requests +const ( + ProtoUDP = 17 + ProtoTCP = 6 +) + +// Result codes (RFC 6887 Section 7.4) +const ( + ResultSuccess = 0 + ResultUnsuppVersion = 1 + ResultNotAuthorized = 2 + ResultMalformedRequest = 3 + ResultUnsuppOpcode = 4 + ResultUnsuppOption = 5 + ResultMalformedOption = 6 + ResultNetworkFailure = 7 + ResultNoResources = 8 + ResultUnsuppProtocol = 9 + ResultUserExQuota = 10 + ResultCannotProvideExt = 11 + ResultAddressMismatch = 12 + ResultExcessiveRemotePeers = 13 +) + +// ResultCodeString returns a human-readable string for a result code. +func ResultCodeString(code uint8) string { + switch code { + case ResultSuccess: + return "SUCCESS" + case ResultUnsuppVersion: + return "UNSUPP_VERSION" + case ResultNotAuthorized: + return "NOT_AUTHORIZED" + case ResultMalformedRequest: + return "MALFORMED_REQUEST" + case ResultUnsuppOpcode: + return "UNSUPP_OPCODE" + case ResultUnsuppOption: + return "UNSUPP_OPTION" + case ResultMalformedOption: + return "MALFORMED_OPTION" + case ResultNetworkFailure: + return "NETWORK_FAILURE" + case ResultNoResources: + return "NO_RESOURCES" + case ResultUnsuppProtocol: + return "UNSUPP_PROTOCOL" + case ResultUserExQuota: + return "USER_EX_QUOTA" + case ResultCannotProvideExt: + return "CANNOT_PROVIDE_EXTERNAL" + case ResultAddressMismatch: + return "ADDRESS_MISMATCH" + case ResultExcessiveRemotePeers: + return "EXCESSIVE_REMOTE_PEERS" + default: + return fmt.Sprintf("UNKNOWN(%d)", code) + } +} + +// Response represents a parsed PCP response header. +type Response struct { + Version uint8 + Opcode uint8 + ResultCode uint8 + Lifetime uint32 + Epoch uint32 +} + +// MapResponse contains the full response to a MAP request. +type MapResponse struct { + Response + Nonce [12]byte + Protocol uint8 + InternalPort uint16 + ExternalPort uint16 + ExternalIP netip.Addr +} + +// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation. +func addrTo16(addr netip.Addr) [16]byte { + if addr.Is4() { + return netip.AddrFrom4(addr.As4()).As16() + } + return addr.As16() +} + +// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4. +func addrFrom16(b [16]byte) netip.Addr { + return netip.AddrFrom16(b).Unmap() +} + +// buildAnnounceRequest creates a PCP ANNOUNCE request packet. +func buildAnnounceRequest(clientIP netip.Addr) []byte { + req := make([]byte, headerSize) + req[0] = Version + req[1] = OpAnnounce + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + return req +} + +// buildMapRequest creates a PCP MAP request packet. +func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte { + req := make([]byte, mapRequestSize) + + // Header + req[0] = Version + req[1] = OpMap + binary.BigEndian.PutUint32(req[4:8], lifetime) + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + + // MAP payload + copy(req[24:36], nonce[:]) + req[36] = protocol + binary.BigEndian.PutUint16(req[40:42], internalPort) + binary.BigEndian.PutUint16(req[42:44], suggestedExtPort) + if suggestedExtIP.IsValid() { + extMapped := addrTo16(suggestedExtIP) + copy(req[44:60], extMapped[:]) + } + + return req +} + +// parseResponse parses the common PCP response header. +func parseResponse(data []byte) (*Response, error) { + if len(data) < headerSize { + return nil, fmt.Errorf("response too short: %d bytes", len(data)) + } + + resp := &Response{ + Version: data[0], + Opcode: data[1], + ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2) + Lifetime: binary.BigEndian.Uint32(data[4:8]), + Epoch: binary.BigEndian.Uint32(data[8:12]), + } + + if resp.Version != Version { + return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version) + } + + if resp.Opcode&OpReply == 0 { + return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode) + } + + return resp, nil +} + +// parseMapResponse parses a complete MAP response. +func parseMapResponse(data []byte) (*MapResponse, error) { + if len(data) < mapRequestSize { + return nil, fmt.Errorf("MAP response too short: %d bytes", len(data)) + } + + resp, err := parseResponse(data) + if err != nil { + return nil, fmt.Errorf("parse header: %w", err) + } + + mapResp := &MapResponse{ + Response: *resp, + Protocol: data[36], + InternalPort: binary.BigEndian.Uint16(data[40:42]), + ExternalPort: binary.BigEndian.Uint16(data[42:44]), + ExternalIP: addrFrom16([16]byte(data[44:60])), + } + copy(mapResp.Nonce[:], data[24:36]) + + return mapResp, nil +} diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go new file mode 100644 index 000000000..b1315cdc0 --- /dev/null +++ b/client/internal/portforward/state.go @@ -0,0 +1,63 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" +) + +// discoverGateway is the function used for NAT gateway discovery. +// It can be replaced in tests to avoid real network operations. +// Tries PCP first, then falls back to NAT-PMP/UPnP. +var discoverGateway = defaultDiscoverGateway + +func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) { + pcpGateway, err := pcp.DiscoverPCP(ctx) + if err == nil { + return pcpGateway, nil + } + log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err) + + return nat.DiscoverGateway(ctx) +} + +// State is persisted only for crash recovery cleanup +type State struct { + InternalPort uint16 `json:"internal_port,omitempty"` + Protocol string `json:"protocol,omitempty"` +} + +func (s *State) Name() string { + return "port_forward_state" +} + +// Cleanup implements statemanager.CleanableState for crash recovery +func (s *State) Cleanup() error { + if s.InternalPort == 0 { + return nil + } + + log.Infof("cleaning up stale port mapping for port %d", s.InternalPort) + + ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) + defer cancel() + + gateway, err := discoverGateway(ctx) + if err != nil { + // Discovery failure is not an error - gateway may not exist + log.Debugf("cleanup: no gateway found: %v", err) + return nil + } + + if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil { + return fmt.Errorf("delete port mapping: %w", err) + } + + return nil +} diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 8f3ff8b11..20c615d57 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -39,6 +39,18 @@ const ( DefaultAdminURL = "https://app.netbird.io:443" ) +// mgmProber is the subset of management client needed for URL migration probes. +type mgmProber interface { + HealthCheck() error + Close() error +} + +// newMgmProber creates a management client for probing URL reachability. +// Overridden in tests to avoid real network calls. +var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) { + return mgm.NewClient(ctx, addr, key, tlsEnabled) +} + var DefaultInterfaceBlacklist = []string{ iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", "Tailscale", "tailscale", "docker", "veth", "br-", "lo", @@ -198,7 +210,7 @@ func getConfigDirForUser(username string) (string, error) { configDir := filepath.Join(DefaultConfigPathDir, username) if _, err := os.Stat(configDir); os.IsNotExist(err) { - if err := os.MkdirAll(configDir, 0600); err != nil { + if err := os.MkdirAll(configDir, 0700); err != nil { return "", err } } @@ -206,9 +218,15 @@ func getConfigDirForUser(username string) (string, error) { return configDir, nil } -func fileExists(path string) bool { +func fileExists(path string) (bool, error) { _, err := os.Stat(path) - return !os.IsNotExist(err) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -635,7 +653,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool { // UpdateConfig update existing configuration according to input configuration and return with the configuration func UpdateConfig(input ConfigInput) (*Config, error) { - if !fileExists(input.ConfigPath) { + configExists, err := fileExists(input.ConfigPath) + if err != nil { + return nil, fmt.Errorf("failed to check if config file exists: %w", err) + } + if !configExists { return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath) } @@ -644,7 +666,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) { // UpdateOrCreateConfig reads existing config or generates a new one func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !fileExists(input.ConfigPath) { + configExists, err := fileExists(input.ConfigPath) + if err != nil { + return nil, fmt.Errorf("failed to check if config file exists: %w", err) + } + if !configExists { log.Infof("generating new config %s", input.ConfigPath) cfg, err := createNewConfig(input) if err != nil { @@ -657,7 +683,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if isPreSharedKeyHidden(input.PreSharedKey) { input.PreSharedKey = nil } - err := util.EnforcePermission(input.ConfigPath) + err = util.EnforcePermission(input.ConfigPath) if err != nil { log.Errorf("failed to enforce permission on config dir: %v", err) } @@ -739,21 +765,19 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri return config, err } - client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled) + client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled) if err != nil { log.Infof("couldn't switch to the new Management %s", newURL.String()) return config, err } defer func() { - err = client.Close() - if err != nil { + if err := client.Close(); err != nil { log.Warnf("failed to close the Management service client %v", err) } }() // gRPC check - _, err = client.GetServerPublicKey() - if err != nil { + if err = client.HealthCheck(); err != nil { log.Infof("couldn't switch to the new Management %s", newURL.String()) return nil, err } @@ -784,7 +808,12 @@ func ReadConfig(configPath string) (*Config, error) { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func readConfig(configPath string, createIfMissing bool) (*Config, error) { - if fileExists(configPath) { + configExists, err := fileExists(configPath) + if err != nil { + return nil, fmt.Errorf("failed to check if config file exists: %w", err) + } + + if configExists { err := util.EnforcePermission(configPath) if err != nil { log.Errorf("failed to enforce permission on config dir: %v", err) @@ -831,7 +860,11 @@ func DirectWriteOutConfig(path string, config *Config) error { // DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes. // Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox). func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !fileExists(input.ConfigPath) { + configExists, err := fileExists(input.ConfigPath) + if err != nil { + return nil, fmt.Errorf("failed to check if config file exists: %w", err) + } + if !configExists { log.Infof("generating new config %s", input.ConfigPath) cfg, err := createNewConfig(input) if err != nil { diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index ab13cf389..5216f2423 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -10,12 +10,21 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) +type mockMgmProber struct{} + +func (m *mockMgmProber) HealthCheck() error { + return nil +} + +func (m *mockMgmProber) Close() error { return nil } + func TestGetConfig(t *testing.T) { // case 1: new default config has to be generated config, err := UpdateOrCreateConfig(ConfigInput{ @@ -234,6 +243,12 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) { } func TestUpdateOldManagementURL(t *testing.T) { + origProber := newMgmProber + newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) { + return &mockMgmProber{}, nil + } + t.Cleanup(func() { newMgmProber = origProber }) + tests := []struct { name string previousManagementURL string @@ -273,18 +288,17 @@ func TestUpdateOldManagementURL(t *testing.T) { ConfigPath: configPath, }) require.NoError(t, err, "failed to create testing config") - previousStats, err := os.Stat(configPath) - require.NoError(t, err, "failed to create testing config stats") + previousContent, err := os.ReadFile(configPath) + require.NoError(t, err, "failed to read initial config") resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath) require.NoError(t, err, "got error when updating old management url") require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String()) - newStats, err := os.Stat(configPath) - require.NoError(t, err, "failed to create testing config stats") - switch tt.fileShouldNotChange { - case true: - require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change") - case false: - require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed") + newContent, err := os.ReadFile(configPath) + require.NoError(t, err, "failed to read updated config") + if tt.fileShouldNotChange { + require.Equal(t, string(previousContent), string(newContent), "file should not change") + } else { + require.NotEqual(t, string(previousContent), string(newContent), "file should have changed") } }) } diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go index bdb722c67..ef3eb1114 100644 --- a/client/internal/profilemanager/service.go +++ b/client/internal/profilemanager/service.go @@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error { } profPath := filepath.Join(configDir, profileName+".json") - if fileExists(profPath) { + profileExists, err := fileExists(profPath) + if err != nil { + return fmt.Errorf("failed to check if profile exists: %w", err) + } + if profileExists { return ErrProfileAlreadyExists } @@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error { return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName) } profPath := filepath.Join(configDir, profileName+".json") - if !fileExists(profPath) { + profileExists, err := fileExists(profPath) + if err != nil { + return fmt.Errorf("failed to check if profile exists: %w", err) + } + if !profileExists { return ErrProfileNotFound } diff --git a/client/internal/profilemanager/state.go b/client/internal/profilemanager/state.go index f84cb1032..f09391ede 100644 --- a/client/internal/profilemanager/state.go +++ b/client/internal/profilemanager/state.go @@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er } stateFile := filepath.Join(configDir, profileName+".state.json") - if !fileExists(stateFile) { + stateFileExists, err := fileExists(stateFile) + if err != nil { + return nil, fmt.Errorf("failed to check if profile state file exists: %w", err) + } + if !stateFileExists { return nil, errors.New("profile state file does not exist") } diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index 0b8e161d2..e6ef8b876 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,7 +3,9 @@ package client import ( "context" "fmt" + "net" "reflect" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -263,8 +265,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe case <-closer: return case routerStates := <-subscription.Events(): - peerStateUpdate <- routerStates - log.Debugf("triggered route state update for Peer: %s", peerKey) + select { + case peerStateUpdate <- routerStates: + log.Debugf("triggered route state update for Peer: %s", peerKey) + case <-ctx.Done(): + return + case <-closer: + return + } } } } @@ -558,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()) + dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 4bf0d5476..64f2a8789 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "net" "net/netip" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) + upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10)) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9afe2049d..3923e153b 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -52,6 +52,7 @@ type Manager interface { TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap + GetSelectedClientRoutes() route.HAMap GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -167,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { NetworkType: route.IPv4Network, } cr = append(cr, fakeIPRoute) + m.notifier.SetFakeIPRoute(fakeIPRoute) } m.notifier.SetInitialClientRoutes(cr, routesForComparison) @@ -465,6 +467,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap { return maps.Clone(m.clientRoutes) } +// GetSelectedClientRoutes returns only the currently selected/active client routes, +// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking +// if traffic should be routed through the tunnel. +func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) +} + // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { m.mux.Lock() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 6b06144b2..66b5e30dd 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -18,6 +18,7 @@ type MockManager struct { TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap + GetSelectedClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route StopFunc func(manager *statemanager.Manager) } @@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } -// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface func (m *MockManager) GetClientRoutes() route.HAMap { if m.GetClientRoutesFunc != nil { return m.GetClientRoutesFunc() @@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap { return nil } +// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface +func (m *MockManager) GetSelectedClientRoutes() route.HAMap { + if m.GetSelectedClientRoutesFunc != nil { + return m.GetSelectedClientRoutesFunc() + } + return nil +} + // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { if m.GetClientRoutesWithNetIDFunc != nil { diff --git a/client/internal/routemanager/notifier/notifier_android.go b/client/internal/routemanager/notifier/notifier_android.go index dec0af87c..55e0b7421 100644 --- a/client/internal/routemanager/notifier/notifier_android.go +++ b/client/internal/routemanager/notifier/notifier_android.go @@ -16,6 +16,7 @@ import ( type Notifier struct { initialRoutes []*route.Route currentRoutes []*route.Route + fakeIPRoute *route.Route listener listener.NetworkChangeListener listenerMux sync.Mutex @@ -31,26 +32,15 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { n.listener = listener } +// SetInitialClientRoutes stores the initial route sets for TUN configuration. func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) { - // initialRoutes contains fake IP block for interface configuration - filteredInitial := make([]*route.Route, 0) - for _, r := range initialRoutes { - if r.IsDynamic() { - continue - } - filteredInitial = append(filteredInitial, r) - } - n.initialRoutes = filteredInitial + n.initialRoutes = filterStatic(initialRoutes) + n.currentRoutes = filterStatic(routesForComparison) +} - // routesForComparison excludes fake IP block for comparison with new routes - filteredComparison := make([]*route.Route, 0) - for _, r := range routesForComparison { - if r.IsDynamic() { - continue - } - filteredComparison = append(filteredComparison, r) - } - n.currentRoutes = filteredComparison +// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild. +func (n *Notifier) SetFakeIPRoute(r *route.Route) { + n.fakeIPRoute = r } func (n *Notifier) OnNewRoutes(idMap route.HAMap) { @@ -83,13 +73,28 @@ func (n *Notifier) notify() { return } - routeStrings := n.routesToStrings(n.currentRoutes) + allRoutes := slices.Clone(n.currentRoutes) + if n.fakeIPRoute != nil { + allRoutes = append(allRoutes, n.fakeIPRoute) + } + + routeStrings := n.routesToStrings(allRoutes) sort.Strings(routeStrings) go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ",")) + l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ",")) }(n.listener) } +func filterStatic(routes []*route.Route) []*route.Route { + out := make([]*route.Route, 0, len(routes)) + for _, r := range routes { + if !r.IsDynamic() { + out = append(out, r) + } + } + return out +} + func (n *Notifier) routesToStrings(routes []*route.Route) []string { nets := make([]string, 0, len(routes)) for _, r := range routes { diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go index bb125cfa4..68c85067a 100644 --- a/client/internal/routemanager/notifier/notifier_ios.go +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // iOS doesn't care about initial routes } +func (n *Notifier) SetFakeIPRoute(*route.Route) { + // Not used on iOS +} + func (n *Notifier) OnNewRoutes(route.HAMap) { // Not used on iOS } @@ -53,7 +57,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { n.currentPrefixes = newNets n.notify() } - func (n *Notifier) notify() { n.listenerMux.Lock() defer n.listenerMux.Unlock() diff --git a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go index 0521e3dc2..97c815cf0 100644 --- a/client/internal/routemanager/notifier/notifier_other.go +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // Not used on non-mobile platforms } +func (n *Notifier) SetFakeIPRoute(*route.Route) { + // Not used on non-mobile platforms +} + func (n *Notifier) OnNewRoutes(idMap route.HAMap) { // Not used on non-mobile platforms } diff --git a/client/internal/routemanager/systemops/systemops_bsd_other.go b/client/internal/routemanager/systemops/systemops_bsd_other.go new file mode 100644 index 000000000..3f09219aa --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_bsd_other.go @@ -0,0 +1,10 @@ +//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin + +package systemops + +// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They +// always fall through to the ref-counter exclusion-route path; these stubs +// exist only so systemops_unix.go compiles. +func (r *SysOps) setupAdvancedRouting() error { return nil } +func (r *SysOps) cleanupAdvancedRouting() error { return nil } +func (r *SysOps) flushPlatformExtras() error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_darwin.go b/client/internal/routemanager/systemops/systemops_darwin.go new file mode 100644 index 000000000..d6875ff95 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_darwin.go @@ -0,0 +1,241 @@ +//go:build darwin && !ios + +package systemops + +import ( + "errors" + "fmt" + "net/netip" + "os" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" +) + +// scopedRouteBudget bounds retries for the scoped default route. Installing or +// deleting it matters enough that we're willing to spend longer waiting for the +// kernel reply than for per-prefix exclusion routes. +const scopedRouteBudget = 5 * time.Second + +// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family +// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can +// resolve gateway'd destinations while the VPN's split default owns the +// unscoped table. +// +// Timing note: this runs during routeManager.Init, which happens before the +// VPN interface is created and before any peer routes propagate. The initial +// mgmt / signal / relay TCP dials always fire before this runs, so those +// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route +// lookup, which at that point correctly picks the physical default. Those +// already-established TCP flows keep their originally-selected interface for +// their lifetime on Darwin because the kernel caches the egress route +// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default +// afterwards does not migrate them since the original en0 default stays in +// the table. Any subsequent reconnect via nbnet.NewDialer picks up the +// populated bound-iface cache and gets IP_BOUND_IF set cleanly. +func (r *SysOps) setupAdvancedRouting() error { + // Drop any previously-cached egress interface before reinstalling. On a + // refresh, a family that no longer resolves would otherwise keep the stale + // binding, causing new sockets to scope to an interface without a matching + // scoped default. + nbnet.ClearBoundInterfaces() + + if err := r.flushScopedDefaults(); err != nil { + log.Warnf("flush residual scoped defaults: %v", err) + } + + var merr *multierror.Error + installed := 0 + + for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} { + ok, err := r.installScopedDefaultFor(unspec) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + if ok { + installed++ + } + } + + if installed == 0 && merr != nil { + return nberrors.FormatErrorOrNil(merr) + } + if merr != nil { + log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr)) + } + return nil +} + +// installScopedDefaultFor resolves the physical default nexthop for the given +// address family, installs a scoped default via it, and caches the iface for +// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds. +func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { + nexthop, err := GetNextHop(unspec) + if err != nil { + if errors.Is(err, vars.ErrRouteNotFound) { + return false, nil + } + return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err) + } + if nexthop.Intf == nil { + return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec) + } + + if err := r.addScopedDefault(unspec, nexthop); err != nil { + return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + } + + af := unix.AF_INET + if unspec.Is6() { + af = unix.AF_INET6 + } + nbnet.SetBoundInterface(af, nexthop.Intf) + via := "point-to-point" + if nexthop.IP.IsValid() { + via = nexthop.IP.String() + } + log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec)) + return true, nil +} + +func (r *SysOps) cleanupAdvancedRouting() error { + nbnet.ClearBoundInterfaces() + return r.flushScopedDefaults() +} + +// flushPlatformExtras runs darwin-specific residual cleanup hooked into the +// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get +// removed on the next boot regardless of whether a profile is brought up. +func (r *SysOps) flushPlatformExtras() error { + return r.flushScopedDefaults() +} + +// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag. +// Safe to call at startup to clear residual entries from a prior session. +func (r *SysOps) flushScopedDefaults() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + removed := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + if rtMsg.Flags&unix.RTF_IFSCOPE == 0 { + continue + } + + info, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("skip scoped flush: %v", err) + continue + } + if !info.Dst.IsValid() || info.Dst.Bits() != 0 { + continue + } + + if err := r.deleteScopedRoute(rtMsg); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w", + info.Dst, rtMsg.Index, err)) + continue + } + removed++ + log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index) + } + + if removed > 0 { + log.Infof("flushed %d residual scoped default route(s)", removed) + } + return nberrors.FormatErrorOrNil(merr) +} + +func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error { + return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop) +} + +func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error { + // Preserve identifying flags from the stored route (including RTF_GATEWAY + // only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE. + keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag + del := &route.RouteMessage{ + Type: unix.RTM_DELETE, + Flags: rtMsg.Flags & keep, + Version: unix.RTM_VERSION, + Seq: r.getSeq(), + Index: rtMsg.Index, + Addrs: rtMsg.Addrs, + } + return r.writeRouteMessage(del, scopedRouteBudget) +} + +func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error { + flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag + + msg := &route.RouteMessage{ + Type: action, + Flags: flags, + Version: unix.RTM_VERSION, + ID: uintptr(os.Getpid()), + Seq: r.getSeq(), + Index: nexthop.Intf.Index, + } + + const numAddrs = unix.RTAX_NETMASK + 1 + addrs := make([]route.Addr, numAddrs) + + dst, err := addrToRouteAddr(unspec) + if err != nil { + return fmt.Errorf("build destination: %w", err) + } + mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0)) + if err != nil { + return fmt.Errorf("build netmask: %w", err) + } + addrs[unix.RTAX_DST] = dst + addrs[unix.RTAX_NETMASK] = mask + + if nexthop.IP.IsValid() { + msg.Flags |= unix.RTF_GATEWAY + gw, err := addrToRouteAddr(nexthop.IP.Unmap()) + if err != nil { + return fmt.Errorf("build gateway: %w", err) + } + addrs[unix.RTAX_GATEWAY] = gw + } else { + addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{ + Index: nexthop.Intf.Index, + Name: nexthop.Intf.Name, + } + } + msg.Addrs = addrs + + return r.writeRouteMessage(msg, scopedRouteBudget) +} + +func afOf(a netip.Addr) string { + if a.Is4() { + return "IPv4" + } + return "IPv6" +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index ec219c7fe..4211eb057 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/net/hooks" ) @@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -var ErrRoutingIsSeparate = errors.New("routing is separate") - func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { stateManager.RegisterState(&ShutdownState{}) @@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { } // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. +// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux, +// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped. func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { - localRoutes, err := hasSeparateRouting() + if nbnet.AdvancedRouting() { + return false, netip.Prefix{} + } + + localRoutes, err := GetRoutesFromTable() if err != nil { - if !errors.Is(err, ErrRoutingIsSeparate) { - log.Errorf("Failed to get routes: %v", err) - } + log.Errorf("Failed to get routes: %v", err) return false, netip.Prefix{} } diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go index 808507fc9..242571b3d 100644 --- a/client/internal/routemanager/systemops/systemops_js.go +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { return []netip.Prefix{}, nil } -func hasSeparateRouting() ([]netip.Prefix, error) { - return []netip.Prefix{}, nil -} - // GetDetailedRoutesFromTable returns empty routes for WASM. func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { return []DetailedRoute{}, nil diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bd10f131f..39a9fd978 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int { return netlink.FAMILY_V6 } -func hasSeparateRouting() ([]netip.Prefix, error) { - if !nbnet.AdvancedRouting() { - return GetRoutesFromTable() - } - return nil, ErrRoutingIsSeparate -} - func isOpErr(err error) bool { // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 905a7bc12..016a62ebd 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -48,10 +48,6 @@ func EnableIPForwarding() error { return nil } -func hasSeparateRouting() ([]netip.Prefix, error) { - return GetRoutesFromTable() -} - // GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) func GetIPRules() ([]IPRule, error) { log.Infof("IP rules collection is not supported on %s", runtime.GOOS) diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 7089178fb..2d3f9b69a 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -25,6 +25,9 @@ import ( const ( envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" + + // routeBudget bounds retries for per-prefix exclusion route programming. + routeBudget = 1 * time.Second ) var routeProtoFlag int @@ -41,26 +44,42 @@ func init() { } func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return r.setupAdvancedRouting() + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return r.cleanupAdvancedRouting() + } + return r.cleanupRefCounter(stateManager) } // FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a +// crashed prior session can't leave crud in the table. func (r *SysOps) FlushMarkedRoutes() error { + var merr *multierror.Error + + if err := r.flushPlatformExtras(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err)) + } + rib, err := retryFetchRIB() if err != nil { - return fmt.Errorf("fetch routing table: %w", err) + return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err))) } msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) if err != nil { - return fmt.Errorf("parse routing table: %w", err) + return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err))) } - var merr *multierror.Error flushedCount := 0 for _, msg := range msgs { @@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e return fmt.Errorf("invalid prefix: %s", prefix) } - expBackOff := backoff.NewExponentialBackOff() - expBackOff.InitialInterval = 50 * time.Millisecond - expBackOff.MaxInterval = 500 * time.Millisecond - expBackOff.MaxElapsedTime = 1 * time.Second + msg, err := r.buildRouteMessage(action, prefix, nexthop) + if err != nil { + return fmt.Errorf("build route message: %w", err) + } - if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil { + if err := r.writeRouteMessage(msg, routeBudget); err != nil { a := "add" if action == unix.RTM_DELETE { a = "remove" @@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e return nil } -func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error { - operation := func() error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) - if err != nil { - return fmt.Errorf("open routing socket: %w", err) +// writeRouteMessage sends a route message over AF_ROUTE and waits for the +// kernel's matching reply, retrying transient failures until budget elapses. +// Callers do not need to manage sockets or seq numbers themselves. +func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error { + expBackOff := backoff.NewExponentialBackOff() + expBackOff.InitialInterval = 50 * time.Millisecond + expBackOff.MaxInterval = 500 * time.Millisecond + expBackOff.MaxElapsedTime = budget + + return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff) +} + +func routeMessageRoundtrip(msg *route.RouteMessage) error { + fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + if err != nil { + return fmt.Errorf("open routing socket: %w", err) + } + defer func() { + if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { + log.Warnf("close routing socket: %v", err) } - defer func() { - if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { - log.Warnf("failed to close routing socket: %v", err) + }() + + tv := unix.Timeval{Sec: 1} + if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { + return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err)) + } + + // AF_ROUTE is a broadcast channel: every route socket on the host sees + // every RTM_* event. With concurrent route programming the default + // per-socket queue overflows and our own reply gets dropped. + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil { + log.Debugf("set SO_RCVBUF on route socket: %v", err) + } + + bytes, err := msg.Marshal() + if err != nil { + return backoff.Permanent(fmt.Errorf("marshal: %w", err)) + } + + if _, err = unix.Write(fd, bytes); err != nil { + if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { + return fmt.Errorf("write: %w", err) + } + return backoff.Permanent(fmt.Errorf("write: %w", err)) + } + return readRouteResponse(fd, msg.Type, msg.Seq) +} + +// readRouteResponse reads from the AF_ROUTE socket until it sees a reply +// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a +// broadcast channel: interface up/down, third-party route changes and neighbor +// discovery events can all land between our write and read, so we must filter. +func readRouteResponse(fd, wantType, wantSeq int) error { + pid := int32(os.Getpid()) + resp := make([]byte, 2048) + deadline := time.Now().Add(time.Second) + for { + if time.Now().After(deadline) { + // Transient: under concurrent pressure the kernel can drop our reply + // from the socket buffer. Let backoff.Retry re-send with a fresh seq. + return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq) + } + n, err := unix.Read(fd, resp) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) { + // SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline. + continue } - }() - - msg, err := r.buildRouteMessage(action, prefix, nexthop) - if err != nil { - return backoff.Permanent(fmt.Errorf("build route message: %w", err)) + return backoff.Permanent(fmt.Errorf("read: %w", err)) } - - msgBytes, err := msg.Marshal() - if err != nil { - return backoff.Permanent(fmt.Errorf("marshal route message: %w", err)) + if n < int(unsafe.Sizeof(unix.RtMsghdr{})) { + continue } - - if _, err = unix.Write(fd, msgBytes); err != nil { - if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { - return fmt.Errorf("write: %w", err) - } - return backoff.Permanent(fmt.Errorf("write: %w", err)) + hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0])) + // Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid) + // uniquely identifies our own reply among broadcast traffic. + if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid { + continue } - - respBuf := make([]byte, 2048) - n, err := unix.Read(fd, respBuf) - if err != nil { - return backoff.Permanent(fmt.Errorf("read route response: %w", err)) + if hdr.Errno != 0 { + return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno))) } - - if n > 0 { - if err := r.parseRouteResponse(respBuf[:n]); err != nil { - return backoff.Permanent(err) - } - } - return nil } - return operation } func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { @@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next Type: action, Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, + ID: uintptr(os.Getpid()), Seq: r.getSeq(), } @@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next return msg, nil } -func (r *SysOps) parseRouteResponse(buf []byte) error { - if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) { - return nil - } - - rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - if rtMsg.Errno != 0 { - return fmt.Errorf("parse: %d", rtMsg.Errno) - } - - return nil -} - // addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { if addr.Is4() { diff --git a/client/internal/sleep/handler/handler.go b/client/internal/sleep/handler/handler.go new file mode 100644 index 000000000..9c2c5d4d5 --- /dev/null +++ b/client/internal/sleep/handler/handler.go @@ -0,0 +1,80 @@ +package handler + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" +) + +type Agent interface { + Up(ctx context.Context) error + Down(ctx context.Context) error + Status() (internal.StatusType, error) +} + +type SleepHandler struct { + agent Agent + + mu sync.Mutex + // sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake + sleepTriggeredDown bool +} + +func New(agent Agent) *SleepHandler { + return &SleepHandler{ + agent: agent, + } +} + +func (s *SleepHandler) HandleWakeUp(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.sleepTriggeredDown { + log.Info("skipping up because wasn't sleep down") + return nil + } + + // avoid other wakeup runs if sleep didn't make the computer sleep + s.sleepTriggeredDown = false + + log.Info("running up after wake up") + err := s.agent.Up(ctx) + if err != nil { + log.Errorf("running up failed: %v", err) + return err + } + + log.Info("running up command executed successfully") + return nil +} + +func (s *SleepHandler) HandleSleep(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + status, err := s.agent.Status() + if err != nil { + return err + } + + if status != internal.StatusConnecting && status != internal.StatusConnected { + log.Infof("skipping setting the agent down because status is %s", status) + return nil + } + + log.Info("running down after system started sleeping") + + if err = s.agent.Down(ctx); err != nil { + log.Errorf("running down failed: %v", err) + return err + } + + s.sleepTriggeredDown = true + + log.Info("running down executed successfully") + return nil +} diff --git a/client/internal/sleep/handler/handler_test.go b/client/internal/sleep/handler/handler_test.go new file mode 100644 index 000000000..9f79428fb --- /dev/null +++ b/client/internal/sleep/handler/handler_test.go @@ -0,0 +1,153 @@ +package handler + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" +) + +type mockAgent struct { + upErr error + downErr error + statusErr error + status internal.StatusType + upCalls int +} + +func (m *mockAgent) Up(_ context.Context) error { + m.upCalls++ + return m.upErr +} + +func (m *mockAgent) Down(_ context.Context) error { + return m.downErr +} + +func (m *mockAgent) Status() (internal.StatusType, error) { + return m.status, m.statusErr +} + +func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) { + agent := &mockAgent{status: status} + return New(agent), agent +} + +func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) { + h, agent := newHandler(internal.StatusIdle) + + err := h.HandleWakeUp(context.Background()) + + require.NoError(t, err) + assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false") +} + +func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) { + h, _ := newHandler(internal.StatusIdle) + h.sleepTriggeredDown = true + + // Even if Up fails, flag should be reset + _ = h.HandleWakeUp(context.Background()) + + assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up") +} + +func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) { + h, agent := newHandler(internal.StatusIdle) + h.sleepTriggeredDown = true + + err := h.HandleWakeUp(context.Background()) + + require.NoError(t, err) + assert.Equal(t, 1, agent.upCalls) + assert.False(t, h.sleepTriggeredDown) +} + +func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) { + h, agent := newHandler(internal.StatusIdle) + h.sleepTriggeredDown = true + agent.upErr = errors.New("up failed") + + err := h.HandleWakeUp(context.Background()) + + assert.ErrorIs(t, err, agent.upErr) + assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails") +} + +func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) { + h, agent := newHandler(internal.StatusIdle) + h.sleepTriggeredDown = true + + _ = h.HandleWakeUp(context.Background()) + err := h.HandleWakeUp(context.Background()) + + require.NoError(t, err) + assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op") +} + +func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Idle", internal.StatusIdle}, + {"NeedsLogin", internal.StatusNeedsLogin}, + {"LoginFailed", internal.StatusLoginFailed}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, _ := newHandler(tt.status) + + err := h.HandleSleep(context.Background()) + + require.NoError(t, err) + assert.False(t, h.sleepTriggeredDown) + }) + } +} + +func TestHandleSleep_ProceedsForActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Connecting", internal.StatusConnecting}, + {"Connected", internal.StatusConnected}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, _ := newHandler(tt.status) + + err := h.HandleSleep(context.Background()) + + require.NoError(t, err) + assert.True(t, h.sleepTriggeredDown) + }) + } +} + +func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) { + agent := &mockAgent{statusErr: errors.New("status error")} + h := New(agent) + + err := h.HandleSleep(context.Background()) + + assert.ErrorIs(t, err, agent.statusErr) + assert.False(t, h.sleepTriggeredDown) +} + +func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) { + agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")} + h := New(agent) + + err := h.HandleSleep(context.Background()) + + assert.ErrorIs(t, err, agent.downErr) + assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails") +} diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go deleted file mode 100644 index 20ddec10d..000000000 --- a/client/internal/updatemanager/manager_test.go +++ /dev/null @@ -1,214 +0,0 @@ -//go:build windows || darwin - -package updatemanager - -import ( - "context" - "fmt" - "path" - "testing" - "time" - - v "github.com/hashicorp/go-version" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -type versionUpdateMock struct { - latestVersion *v.Version - onUpdate func() -} - -func (v versionUpdateMock) StopWatch() {} - -func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { - return false -} - -func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { - v.onUpdate = updateFn -} - -func (v versionUpdateMock) LatestVersion() *v.Version { - return v.latestVersion -} - -func (v versionUpdateMock) StartFetcher() {} - -func Test_LatestVersion(t *testing.T) { - testMatrix := []struct { - name string - daemonVersion string - initialLatestVersion *v.Version - latestVersion *v.Version - shouldUpdateInit bool - shouldUpdateLater bool - }{ - { - name: "Should only trigger update once due to time between triggers being < 5 Minutes", - daemonVersion: "1.0.0", - initialLatestVersion: v.Must(v.NewSemver("1.0.1")), - latestVersion: v.Must(v.NewSemver("1.0.2")), - shouldUpdateInit: true, - shouldUpdateLater: false, - }, - { - name: "Shouldn't update initially, but should update as soon as latest version is fetched", - daemonVersion: "1.0.0", - initialLatestVersion: nil, - latestVersion: v.Must(v.NewSemver("1.0.1")), - shouldUpdateInit: false, - shouldUpdateLater: true, - }, - } - - for idx, c := range testMatrix { - mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} - tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) - m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) - m.update = mockUpdate - - targetVersionChan := make(chan string, 1) - - m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { - targetVersionChan <- targetVersion - return nil - } - m.currentVersion = c.daemonVersion - m.Start(context.Background()) - m.SetVersion("latest") - var triggeredInit bool - select { - case targetVersion := <-targetVersionChan: - if targetVersion != c.initialLatestVersion.String() { - t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) - } - triggeredInit = true - case <-time.After(10 * time.Millisecond): - triggeredInit = false - } - if triggeredInit != c.shouldUpdateInit { - t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) - } - - mockUpdate.latestVersion = c.latestVersion - mockUpdate.onUpdate() - - var triggeredLater bool - select { - case targetVersion := <-targetVersionChan: - if targetVersion != c.latestVersion.String() { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) - } - triggeredLater = true - case <-time.After(10 * time.Millisecond): - triggeredLater = false - } - if triggeredLater != c.shouldUpdateLater { - t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) - } - - m.Stop() - } -} - -func Test_HandleUpdate(t *testing.T) { - testMatrix := []struct { - name string - daemonVersion string - latestVersion *v.Version - expectedVersion string - shouldUpdate bool - }{ - { - name: "Update to a specific version should update regardless of if latestVersion is available yet", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.56.0", - shouldUpdate: true, - }, - { - name: "Update to specific version should not update if version matches", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.55.0", - shouldUpdate: false, - }, - { - name: "Update to specific version should not update if current version is newer", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.54.0", - shouldUpdate: false, - }, - { - name: "Update to latest version should update if latest is newer", - daemonVersion: "0.55.0", - latestVersion: v.Must(v.NewSemver("0.56.0")), - expectedVersion: "latest", - shouldUpdate: true, - }, - { - name: "Update to latest version should not update if latest == current", - daemonVersion: "0.56.0", - latestVersion: v.Must(v.NewSemver("0.56.0")), - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if daemon version is invalid", - daemonVersion: "development", - latestVersion: v.Must(v.NewSemver("1.0.0")), - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if expecting latest and latest version is unavailable", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if expected version is invalid", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "development", - shouldUpdate: false, - }, - } - for idx, c := range testMatrix { - tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) - m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) - m.update = &versionUpdateMock{latestVersion: c.latestVersion} - targetVersionChan := make(chan string, 1) - - m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { - targetVersionChan <- targetVersion - return nil - } - - m.currentVersion = c.daemonVersion - m.Start(context.Background()) - m.SetVersion(c.expectedVersion) - - var updateTriggered bool - select { - case targetVersion := <-targetVersionChan: - if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) - } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion) - } - updateTriggered = true - case <-time.After(10 * time.Millisecond): - updateTriggered = false - } - - if updateTriggered != c.shouldUpdate { - t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) - } - m.Stop() - } -} diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go deleted file mode 100644 index 4e87c2d77..000000000 --- a/client/internal/updatemanager/manager_unsupported.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build !windows && !darwin - -package updatemanager - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -// Manager is a no-op stub for unsupported platforms -type Manager struct{} - -// NewManager returns a no-op manager for unsupported platforms -func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { - return nil, fmt.Errorf("update manager is not supported on this platform") -} - -// CheckUpdateSuccess is a no-op on unsupported platforms -func (m *Manager) CheckUpdateSuccess(ctx context.Context) { - // no-op -} - -// Start is a no-op on unsupported platforms -func (m *Manager) Start(ctx context.Context) { - // no-op -} - -// SetVersion is a no-op on unsupported platforms -func (m *Manager) SetVersion(expectedVersion string) { - // no-op -} - -// Stop is a no-op on unsupported platforms -func (m *Manager) Stop() { - // no-op -} diff --git a/client/internal/updatemanager/doc.go b/client/internal/updater/doc.go similarity index 93% rename from client/internal/updatemanager/doc.go rename to client/internal/updater/doc.go index 54d1bdeab..e1924aa43 100644 --- a/client/internal/updatemanager/doc.go +++ b/client/internal/updater/doc.go @@ -1,4 +1,4 @@ -// Package updatemanager provides automatic update management for the NetBird client. +// Package updater provides automatic update management for the NetBird client. // It monitors for new versions, handles update triggers from management server directives, // and orchestrates the download and installation of client updates. // @@ -32,4 +32,4 @@ // // This enables verification of successful updates and appropriate user notification // after the client restarts with the new version. -package updatemanager +package updater diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updater/downloader/downloader.go similarity index 100% rename from client/internal/updatemanager/downloader/downloader.go rename to client/internal/updater/downloader/downloader.go diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updater/downloader/downloader_test.go similarity index 100% rename from client/internal/updatemanager/downloader/downloader_test.go rename to client/internal/updater/downloader/downloader_test.go diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updater/installer/binary_nowindows.go similarity index 100% rename from client/internal/updatemanager/installer/binary_nowindows.go rename to client/internal/updater/installer/binary_nowindows.go diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updater/installer/binary_windows.go similarity index 100% rename from client/internal/updatemanager/installer/binary_windows.go rename to client/internal/updater/installer/binary_windows.go diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updater/installer/doc.go similarity index 100% rename from client/internal/updatemanager/installer/doc.go rename to client/internal/updater/installer/doc.go diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updater/installer/installer.go similarity index 100% rename from client/internal/updatemanager/installer/installer.go rename to client/internal/updater/installer/installer.go diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updater/installer/installer_common.go similarity index 97% rename from client/internal/updatemanager/installer/installer_common.go rename to client/internal/updater/installer/installer_common.go index 03378d55f..8e44bee82 100644 --- a/client/internal/updatemanager/installer/installer_common.go +++ b/client/internal/updater/installer/installer_common.go @@ -16,8 +16,8 @@ import ( goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/downloader" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) type Installer struct { diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updater/installer/installer_log_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/installer_log_darwin.go rename to client/internal/updater/installer/installer_log_darwin.go diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updater/installer/installer_log_windows.go similarity index 100% rename from client/internal/updatemanager/installer/installer_log_windows.go rename to client/internal/updater/installer/installer_log_windows.go diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updater/installer/installer_run_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/installer_run_darwin.go rename to client/internal/updater/installer/installer_run_darwin.go diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updater/installer/installer_run_windows.go similarity index 100% rename from client/internal/updatemanager/installer/installer_run_windows.go rename to client/internal/updater/installer/installer_run_windows.go diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updater/installer/log.go similarity index 100% rename from client/internal/updatemanager/installer/log.go rename to client/internal/updater/installer/log.go diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updater/installer/procattr_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/procattr_darwin.go rename to client/internal/updater/installer/procattr_darwin.go diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updater/installer/procattr_windows.go similarity index 100% rename from client/internal/updatemanager/installer/procattr_windows.go rename to client/internal/updater/installer/procattr_windows.go diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updater/installer/repourl_dev.go similarity index 100% rename from client/internal/updatemanager/installer/repourl_dev.go rename to client/internal/updater/installer/repourl_dev.go diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updater/installer/repourl_prod.go similarity index 100% rename from client/internal/updatemanager/installer/repourl_prod.go rename to client/internal/updater/installer/repourl_prod.go diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updater/installer/result.go similarity index 98% rename from client/internal/updatemanager/installer/result.go rename to client/internal/updater/installer/result.go index 03d08d527..526c3eb53 100644 --- a/client/internal/updatemanager/installer/result.go +++ b/client/internal/updater/installer/result.go @@ -203,7 +203,10 @@ func (rh *ResultHandler) write(result Result) error { func (rh *ResultHandler) cleanup() error { err := os.Remove(rh.resultFile) - if err != nil && !os.IsNotExist(err) { + if err != nil { + if os.IsNotExist(err) { + return nil + } return err } log.Debugf("delete installer result file: %s", rh.resultFile) diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updater/installer/types.go similarity index 100% rename from client/internal/updatemanager/installer/types.go rename to client/internal/updater/installer/types.go diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updater/installer/types_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/types_darwin.go rename to client/internal/updater/installer/types_darwin.go diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updater/installer/types_windows.go similarity index 100% rename from client/internal/updatemanager/installer/types_windows.go rename to client/internal/updater/installer/types_windows.go diff --git a/client/internal/updatemanager/manager.go b/client/internal/updater/manager.go similarity index 52% rename from client/internal/updatemanager/manager.go rename to client/internal/updater/manager.go index eae11de56..dfcb93177 100644 --- a/client/internal/updatemanager/manager.go +++ b/client/internal/updater/manager.go @@ -1,12 +1,9 @@ -//go:build windows || darwin - -package updatemanager +package updater import ( "context" "errors" "fmt" - "runtime" "sync" "time" @@ -15,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -41,6 +38,9 @@ type Manager struct { statusRecorder *peer.Status stateManager *statemanager.Manager + downloadOnly bool // true when no enforcement from management; notifies UI to download latest + forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly + lastTrigger time.Time mgmUpdateChan chan struct{} updateChannel chan struct{} @@ -53,37 +53,38 @@ type Manager struct { expectedVersion *v.Version updateToLatestVersion bool - // updateMutex protect update and expectedVersion fields + pendingVersion *v.Version + + // updateMutex protects update, expectedVersion, updateToLatestVersion, + // downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields updateMutex sync.Mutex - triggerUpdateFn func(context.Context, string) error + // installMutex and installing guard against concurrent installation attempts + installMutex sync.Mutex + installing bool + + // protect to start the service multiple times + mu sync.Mutex + + autoUpdateSupported func() bool } -func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { - if runtime.GOOS == "darwin" { - isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable() - if isBrew { - log.Warnf("auto-update disabled on Home Brew installation") - return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet") - } - } - return newManager(statusRecorder, stateManager) -} - -func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { +// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted. +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager { manager := &Manager{ - statusRecorder: statusRecorder, - stateManager: stateManager, - mgmUpdateChan: make(chan struct{}, 1), - updateChannel: make(chan struct{}, 1), - currentVersion: version.NetbirdVersion(), - update: version.NewUpdate("nb/client"), + statusRecorder: statusRecorder, + stateManager: stateManager, + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + update: version.NewUpdate("nb/client"), + downloadOnly: true, + autoUpdateSupported: isAutoUpdateSupported, } - manager.triggerUpdateFn = manager.triggerUpdate stateManager.RegisterState(&UpdateState{}) - return manager, nil + return manager } // CheckUpdateSuccess checks if the update was successful and send a notification. @@ -124,8 +125,10 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) { } func (m *Manager) Start(ctx context.Context) { + log.Infof("starting update manager") + m.mu.Lock() + defer m.mu.Unlock() if m.cancel != nil { - log.Errorf("Manager already started") return } @@ -142,13 +145,32 @@ func (m *Manager) Start(ctx context.Context) { m.cancel = cancel m.wg.Add(1) - go m.updateLoop(ctx) + go func() { + defer m.wg.Done() + m.updateLoop(ctx) + }() } -func (m *Manager) SetVersion(expectedVersion string) { - log.Infof("set expected agent version for upgrade: %s", expectedVersion) - if m.cancel == nil { - log.Errorf("manager not started") +func (m *Manager) SetDownloadOnly() { + m.updateMutex.Lock() + m.downloadOnly = true + m.forceUpdate = false + m.expectedVersion = nil + m.updateToLatestVersion = false + m.lastTrigger = time.Time{} + m.updateMutex.Unlock() + + select { + case m.mgmUpdateChan <- struct{}{}: + default: + } +} + +func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) { + log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate) + + if !m.autoUpdateSupported() { + log.Warnf("auto-update not supported on this platform") return } @@ -159,6 +181,7 @@ func (m *Manager) SetVersion(expectedVersion string) { log.Errorf("empty expected version provided") m.expectedVersion = nil m.updateToLatestVersion = false + m.downloadOnly = true return } @@ -178,12 +201,97 @@ func (m *Manager) SetVersion(expectedVersion string) { m.updateToLatestVersion = false } + m.lastTrigger = time.Time{} + m.downloadOnly = false + m.forceUpdate = forceUpdate + select { case m.mgmUpdateChan <- struct{}{}: default: } } +// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI. +func (m *Manager) Install(ctx context.Context) error { + if !m.autoUpdateSupported() { + return fmt.Errorf("auto-update not supported on this platform") + } + + m.updateMutex.Lock() + pending := m.pendingVersion + m.updateMutex.Unlock() + + if pending == nil { + return fmt.Errorf("no pending version to install") + } + + return m.tryInstall(ctx, pending) +} + +// tryInstall ensures only one installation runs at a time. Concurrent callers +// receive an error immediately rather than queuing behind a running install. +func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error { + m.installMutex.Lock() + if m.installing { + m.installMutex.Unlock() + return fmt.Errorf("installation already in progress") + } + m.installing = true + m.installMutex.Unlock() + + defer func() { + m.installMutex.Lock() + m.installing = false + m.installMutex.Unlock() + }() + + return m.install(ctx, targetVersion) +} + +// NotifyUI re-publishes the current update state to a newly connected UI client. +// Only needed for download-only mode where the latest version is already cached +// NotifyUI re-publishes the current update state so a newly connected UI gets the info. +func (m *Manager) NotifyUI() { + m.updateMutex.Lock() + if m.update == nil { + m.updateMutex.Unlock() + return + } + downloadOnly := m.downloadOnly + pendingVersion := m.pendingVersion + latestVersion := m.update.LatestVersion() + m.updateMutex.Unlock() + + if downloadOnly { + if latestVersion == nil { + return + } + currentVersion, err := v.NewVersion(m.currentVersion) + if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) { + return + } + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": latestVersion.String()}, + ) + return + } + + if pendingVersion != nil { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"}, + ) + } +} + +// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it. func (m *Manager) Stop() { if m.cancel == nil { return @@ -214,8 +322,6 @@ func (m *Manager) onContextCancel() { } func (m *Manager) updateLoop(ctx context.Context) { - defer m.wg.Done() - for { select { case <-ctx.Done(): @@ -239,55 +345,89 @@ func (m *Manager) handleUpdate(ctx context.Context) { return } - expectedVersion := m.expectedVersion - useLatest := m.updateToLatestVersion + downloadOnly := m.downloadOnly + forceUpdate := m.forceUpdate curLatestVersion := m.update.LatestVersion() - m.updateMutex.Unlock() switch { - // Resolve "latest" to actual version - case useLatest: + // Download-only mode or resolve "latest" to actual version + case downloadOnly, m.updateToLatestVersion: if curLatestVersion == nil { log.Tracef("latest version not fetched yet") + m.updateMutex.Unlock() return } updateVersion = curLatestVersion - // Update to specific version - case expectedVersion != nil: - updateVersion = expectedVersion + // Install to specific version + case m.expectedVersion != nil: + updateVersion = m.expectedVersion default: log.Debugf("no expected version information set") + m.updateMutex.Unlock() return } log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion) - if !m.shouldUpdate(updateVersion) { + if !m.shouldUpdate(updateVersion, forceUpdate) { + m.updateMutex.Unlock() return } m.lastTrigger = time.Now() - log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion) - m.statusRecorder.PublishEvent( - cProto.SystemEvent_CRITICAL, - cProto.SystemEvent_SYSTEM, - "Automatically updating client", - "Your client version is older than auto-update version set in Management, updating client now.", - nil, - ) + log.Infof("new version available: %s", updateVersion) + + if !downloadOnly && !forceUpdate { + m.pendingVersion = updateVersion + } + m.updateMutex.Unlock() + + if downloadOnly { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": updateVersion.String()}, + ) + return + } + + if forceUpdate { + if err := m.tryInstall(ctx, updateVersion); err != nil { + log.Errorf("force update failed: %v", err) + } + return + } + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"}, + ) +} + +func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "Updating client", + "Installing update now.", + nil, + ) m.statusRecorder.PublishEvent( cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM, "", "", - map[string]string{"progress_window": "show", "version": updateVersion.String()}, + map[string]string{"progress_window": "show", "version": pendingVersion.String()}, ) updateState := UpdateState{ PreUpdateVersion: m.currentVersion, - TargetVersion: updateVersion.String(), + TargetVersion: pendingVersion.String(), } - if err := m.stateManager.UpdateState(updateState); err != nil { log.Warnf("failed to update state: %v", err) } else { @@ -296,8 +436,9 @@ func (m *Manager) handleUpdate(ctx context.Context) { } } - if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil { - log.Errorf("Error triggering auto-update: %v", err) + inst := installer.New() + if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil { + log.Errorf("error triggering update: %v", err) m.statusRecorder.PublishEvent( cProto.SystemEvent_ERROR, cProto.SystemEvent_SYSTEM, @@ -305,7 +446,9 @@ func (m *Manager) handleUpdate(ctx context.Context) { fmt.Sprintf("Auto-update failed: %v", err), nil, ) + return err } + return nil } // loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it. @@ -339,7 +482,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e return updateState, nil } -func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { +func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool { if m.currentVersion == developmentVersion { log.Debugf("skipping auto-update, running development version") return false @@ -354,8 +497,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { return false } - if time.Since(m.lastTrigger) < 5*time.Minute { - log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) + if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute { + log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) return false } @@ -367,8 +510,3 @@ func (m *Manager) lastResultErrReason() string { result := installer.NewResultHandler(inst.TempDir()) return result.GetErrorResultReason() } - -func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error { - inst := installer.New() - return inst.RunInstallation(ctx, targetVersion) -} diff --git a/client/internal/updater/manager_linux_test.go b/client/internal/updater/manager_linux_test.go new file mode 100644 index 000000000..b05dd7e7d --- /dev/null +++ b/client/internal/updater/manager_linux_test.go @@ -0,0 +1,111 @@ +//go:build !windows && !darwin + +package updater + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// On Linux, only Mode 1 (downloadOnly) is supported. +// SetVersion is a no-op because auto-update installation is not supported. + +func Test_LatestVersion_Linux(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should notify again when a newer version arrives even within 5 minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: true, + }, + { + name: "Shouldn't notify initially, but should notify as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = mockUpdate + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetDownloadOnly() + + ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredInit := ver != "" + if enforced { + t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name) + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial notify mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() { + t.Errorf("%s: Initial version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + ver, enforced = waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredLater := ver != "" + if enforced { + t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name) + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Later notify mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Later version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } + + m.Stop() + } +} + +func Test_SetVersion_NoOp_Linux(t *testing.T) { + // On Linux, SetVersion should be a no-op — no events fired + tmpFile := path.Join(t.TempDir(), "update-test-noop.json") + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))} + m.currentVersion = "1.0.0" + m.Start(context.Background()) + m.SetVersion("1.0.1", false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + if ver != "" { + t.Errorf("SetVersion should be a no-op on Linux, but got event with version %s", ver) + } + + m.Stop() +} diff --git a/client/internal/updater/manager_test.go b/client/internal/updater/manager_test.go new file mode 100644 index 000000000..107dca2b3 --- /dev/null +++ b/client/internal/updater/manager_test.go @@ -0,0 +1,227 @@ +//go:build windows || darwin + +package updater + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + cProto "github.com/netbirdio/netbird/client/proto" +) + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should notify again when a newer version arrives even within 5 minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: true, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = mockUpdate + m.currentVersion = c.daemonVersion + m.autoUpdateSupported = func() bool { return true } + m.Start(context.Background()) + m.SetVersion("latest", false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredInit := ver != "" + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + ver, _ = waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredLater := ver != "" + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Later update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Later update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Install to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Install to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Install to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Install to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Install to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for idx, c := range testMatrix { + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: c.latestVersion} + m.currentVersion = c.daemonVersion + m.autoUpdateSupported = func() bool { return true } + m.Start(context.Background()) + m.SetVersion(c.expectedVersion, false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + updateTriggered := ver != "" + + if updateTriggered { + if c.expectedVersion == "latest" && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } else if c.expectedVersion != "latest" && c.expectedVersion != "development" && ver != c.expectedVersion { + t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.expectedVersion, ver) + } + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} + +func Test_EnforcedMetadata(t *testing.T) { + // Mode 1 (downloadOnly): no enforced metadata + tmpFile := path.Join(t.TempDir(), "update-test-mode1.json") + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))} + m.currentVersion = "1.0.0" + m.Start(context.Background()) + m.SetDownloadOnly() + + ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond) + if ver == "" { + t.Fatal("Mode 1: expected new_version_available event") + } + if enforced { + t.Error("Mode 1: expected no enforced metadata") + } + m.Stop() + + // Mode 2 (enforced, forceUpdate=false): enforced metadata present, no auto-install + tmpFile2 := path.Join(t.TempDir(), "update-test-mode2.json") + recorder2 := peer.NewRecorder("") + sub2 := recorder2.SubscribeToEvents() + defer recorder2.UnsubscribeFromEvents(sub2) + + m2 := NewManager(recorder2, statemanager.New(tmpFile2)) + m2.update = &versionUpdateMock{latestVersion: nil} + m2.currentVersion = "1.0.0" + m2.autoUpdateSupported = func() bool { return true } + m2.Start(context.Background()) + m2.SetVersion("1.0.1", false) + + ver, enforced2 := waitForUpdateEvent(sub2, 500*time.Millisecond) + if ver == "" { + t.Fatal("Mode 2: expected new_version_available event") + } + if !enforced2 { + t.Error("Mode 2: expected enforced metadata") + } + m2.Stop() +} + +// ensure the proto import is used +var _ = cProto.SystemEvent_INFO diff --git a/client/internal/updater/manager_test_helpers_test.go b/client/internal/updater/manager_test_helpers_test.go new file mode 100644 index 000000000..c7faee1f4 --- /dev/null +++ b/client/internal/updater/manager_test_helpers_test.go @@ -0,0 +1,56 @@ +package updater + +import ( + "strconv" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (m versionUpdateMock) StopWatch() {} + +func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + m.onUpdate = updateFn +} + +func (m versionUpdateMock) LatestVersion() *v.Version { + return m.latestVersion +} + +func (m versionUpdateMock) StartFetcher() {} + +// waitForUpdateEvent waits for a new_version_available event, returns the version string or "" on timeout. +func waitForUpdateEvent(sub *peer.EventSubscription, timeout time.Duration) (version string, enforced bool) { + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case event, ok := <-sub.Events(): + if !ok { + return "", false + } + if val, ok := event.Metadata["new_version_available"]; ok { + enforced := false + if raw, ok := event.Metadata["enforced"]; ok { + if parsed, err := strconv.ParseBool(raw); err == nil { + enforced = parsed + } + } + return val, enforced + } + case <-timer.C: + return "", false + } + } +} diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updater/reposign/artifact.go similarity index 100% rename from client/internal/updatemanager/reposign/artifact.go rename to client/internal/updater/reposign/artifact.go diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updater/reposign/artifact_test.go similarity index 100% rename from client/internal/updatemanager/reposign/artifact_test.go rename to client/internal/updater/reposign/artifact_test.go diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updater/reposign/certs/root-pub.pem similarity index 100% rename from client/internal/updatemanager/reposign/certs/root-pub.pem rename to client/internal/updater/reposign/certs/root-pub.pem diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updater/reposign/certsdev/root-pub.pem similarity index 100% rename from client/internal/updatemanager/reposign/certsdev/root-pub.pem rename to client/internal/updater/reposign/certsdev/root-pub.pem diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updater/reposign/doc.go similarity index 100% rename from client/internal/updatemanager/reposign/doc.go rename to client/internal/updater/reposign/doc.go diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updater/reposign/embed_dev.go similarity index 100% rename from client/internal/updatemanager/reposign/embed_dev.go rename to client/internal/updater/reposign/embed_dev.go diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updater/reposign/embed_prod.go similarity index 100% rename from client/internal/updatemanager/reposign/embed_prod.go rename to client/internal/updater/reposign/embed_prod.go diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updater/reposign/key.go similarity index 100% rename from client/internal/updatemanager/reposign/key.go rename to client/internal/updater/reposign/key.go diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updater/reposign/key_test.go similarity index 100% rename from client/internal/updatemanager/reposign/key_test.go rename to client/internal/updater/reposign/key_test.go diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updater/reposign/revocation.go similarity index 100% rename from client/internal/updatemanager/reposign/revocation.go rename to client/internal/updater/reposign/revocation.go diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updater/reposign/revocation_test.go similarity index 100% rename from client/internal/updatemanager/reposign/revocation_test.go rename to client/internal/updater/reposign/revocation_test.go diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updater/reposign/root.go similarity index 100% rename from client/internal/updatemanager/reposign/root.go rename to client/internal/updater/reposign/root.go diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updater/reposign/root_test.go similarity index 100% rename from client/internal/updatemanager/reposign/root_test.go rename to client/internal/updater/reposign/root_test.go diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updater/reposign/signature.go similarity index 100% rename from client/internal/updatemanager/reposign/signature.go rename to client/internal/updater/reposign/signature.go diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updater/reposign/signature_test.go similarity index 100% rename from client/internal/updatemanager/reposign/signature_test.go rename to client/internal/updater/reposign/signature_test.go diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updater/reposign/verify.go similarity index 98% rename from client/internal/updatemanager/reposign/verify.go rename to client/internal/updater/reposign/verify.go index 0af2a8c9e..f64b26a30 100644 --- a/client/internal/updatemanager/reposign/verify.go +++ b/client/internal/updater/reposign/verify.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" + "github.com/netbirdio/netbird/client/internal/updater/downloader" ) const ( diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updater/reposign/verify_test.go similarity index 100% rename from client/internal/updatemanager/reposign/verify_test.go rename to client/internal/updater/reposign/verify_test.go diff --git a/client/internal/updater/supported_darwin.go b/client/internal/updater/supported_darwin.go new file mode 100644 index 000000000..b27754366 --- /dev/null +++ b/client/internal/updater/supported_darwin.go @@ -0,0 +1,22 @@ +package updater + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updater/installer" +) + +func isAutoUpdateSupported() bool { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + isBrew := !installer.TypeOfInstaller(ctx).Downloadable() + if isBrew { + log.Warnf("auto-update disabled on Homebrew installation") + return false + } + return true +} diff --git a/client/internal/updater/supported_other.go b/client/internal/updater/supported_other.go new file mode 100644 index 000000000..e09e8c3a3 --- /dev/null +++ b/client/internal/updater/supported_other.go @@ -0,0 +1,7 @@ +//go:build !windows && !darwin + +package updater + +func isAutoUpdateSupported() bool { + return false +} diff --git a/client/internal/updater/supported_windows.go b/client/internal/updater/supported_windows.go new file mode 100644 index 000000000..0c28878c7 --- /dev/null +++ b/client/internal/updater/supported_windows.go @@ -0,0 +1,5 @@ +package updater + +func isAutoUpdateSupported() bool { + return true +} diff --git a/client/internal/updatemanager/update.go b/client/internal/updater/update.go similarity index 90% rename from client/internal/updatemanager/update.go rename to client/internal/updater/update.go index 875b50b49..3056c77e1 100644 --- a/client/internal/updatemanager/update.go +++ b/client/internal/updater/update.go @@ -1,4 +1,4 @@ -package updatemanager +package updater import v "github.com/hashicorp/go-version" diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index aafef41d3..043673904 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -160,8 +160,12 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + hostDNS := []netip.AddrPort{ + netip.MustParseAddrPort("9.9.9.9:53"), + netip.MustParseAddrPort("149.112.112.112:53"), + } + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile) } // Stop the internal client and free the resources diff --git a/client/net/dialer_init_darwin.go b/client/net/dialer_init_darwin.go new file mode 100644 index 000000000..e18909ff7 --- /dev/null +++ b/client/net/dialer_init_darwin.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyBoundIfToSocket +} diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go index 18ebc6ad1..78973b47d 100644 --- a/client/net/dialer_init_generic.go +++ b/client/net/dialer_init_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !linux && !windows && !darwin package net diff --git a/client/net/env_android.go b/client/net/env_android.go deleted file mode 100644 index 9d89951a1..000000000 --- a/client/net/env_android.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build android - -package net - -// Init initializes the network environment for Android -func Init() { - // No initialization needed on Android -} - -// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. -// Always returns true on Android since we cannot handle routes dynamically. -func AdvancedRouting() bool { - return true -} - -// SetVPNInterfaceName is a no-op on Android -func SetVPNInterfaceName(name string) { - // No-op on Android - not needed for Android VPN service -} - -// GetVPNInterfaceName returns empty string on Android -func GetVPNInterfaceName() string { - return "" -} diff --git a/client/net/env_windows.go b/client/net/env_bound_iface.go similarity index 71% rename from client/net/env_windows.go rename to client/net/env_bound_iface.go index 7e8868ba5..593988c2c 100644 --- a/client/net/env_windows.go +++ b/client/net/env_bound_iface.go @@ -1,4 +1,4 @@ -//go:build windows +//go:build (darwin && !ios) || windows package net @@ -24,17 +24,22 @@ func Init() { } func checkAdvancedRoutingSupport() bool { - var err error - var legacyRouting bool + legacyRouting := false if val := os.Getenv(envUseLegacyRouting); val != "" { - legacyRouting, err = strconv.ParseBool(val) + parsed, err := strconv.ParseBool(val) if err != nil { - log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err) + } else { + legacyRouting = parsed } } - if legacyRouting || netstack.IsEnabled() { - log.Info("advanced routing has been requested to be disabled") + if legacyRouting { + log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting) + return false + } + if netstack.IsEnabled() { + log.Info("advanced routing disabled: netstack mode is enabled") return false } diff --git a/client/net/env_generic.go b/client/net/env_generic.go index f467930c3..18c10bb78 100644 --- a/client/net/env_generic.go +++ b/client/net/env_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows && !android +//go:build !linux && !windows && !darwin package net diff --git a/client/net/env_mobile.go b/client/net/env_mobile.go new file mode 100644 index 000000000..80b0fad8d --- /dev/null +++ b/client/net/env_mobile.go @@ -0,0 +1,25 @@ +//go:build ios || android + +package net + +// Init initializes the network environment for mobile platforms. +func Init() { + // no-op on mobile: routing scope is owned by the VPN extension. +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension +// owns the routing scope. +func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on mobile. +func SetVPNInterfaceName(string) { + // no-op on mobile: the VPN extension manages the interface. +} + +// GetVPNInterfaceName returns an empty string on mobile. +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/listener_init_darwin.go b/client/net/listener_init_darwin.go new file mode 100644 index 000000000..f2fcc80ed --- /dev/null +++ b/client/net/listener_init_darwin.go @@ -0,0 +1,5 @@ +package net + +func (l *ListenerConfig) init() { + l.ListenConfig.Control = applyBoundIfToSocket +} diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go index 4f8f17ab2..65a785222 100644 --- a/client/net/listener_init_generic.go +++ b/client/net/listener_init_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !linux && !windows && !darwin package net diff --git a/client/net/net_darwin.go b/client/net/net_darwin.go new file mode 100644 index 000000000..00d858a6a --- /dev/null +++ b/client/net/net_darwin.go @@ -0,0 +1,160 @@ +package net + +import ( + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "sync" + "syscall" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack +// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6" +// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns +// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is +// dispatched by socket domain (AF_INET only) not by inp_vflag. + +// boundIface holds the physical interface chosen at routing setup time. Sockets +// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF +// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup +// hits the RTF_IFSCOPE default installed by the routemanager, rather than +// following the VPN's split default. +var ( + boundIfaceMu sync.RWMutex + boundIface4 *net.Interface + boundIface6 *net.Interface +) + +// SetBoundInterface records the egress interface for an address family. Called +// by the routemanager after a scoped default route has been installed. +// af must be unix.AF_INET or unix.AF_INET6; other values are ignored. +// nil iface is rejected — use ClearBoundInterfaces to clear all slots. +func SetBoundInterface(af int, iface *net.Interface) { + if iface == nil { + log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af) + return + } + boundIfaceMu.Lock() + defer boundIfaceMu.Unlock() + switch af { + case unix.AF_INET: + boundIface4 = iface + case unix.AF_INET6: + boundIface6 = iface + default: + log.Warnf("SetBoundInterface: unsupported address family %d", af) + } +} + +// ClearBoundInterfaces resets the cached egress interfaces. Called by the +// routemanager during cleanup. +func ClearBoundInterfaces() { + boundIfaceMu.Lock() + defer boundIfaceMu.Unlock() + boundIface4 = nil + boundIface6 = nil +} + +// boundInterfaceFor returns the cached egress interface for a socket's address +// family, falling back to the other family if the preferred slot is empty. +// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so +// either setsockopt scopes the socket; preferring same-family still matters +// when v4 and v6 defaults egress different NICs. +func boundInterfaceFor(network, address string) *net.Interface { + if iface := zoneInterface(address); iface != nil { + return iface + } + + boundIfaceMu.RLock() + defer boundIfaceMu.RUnlock() + + primary, secondary := boundIface4, boundIface6 + if isV6Network(network) { + primary, secondary = boundIface6, boundIface4 + } + if primary != nil { + return primary + } + return secondary +} + +func isV6Network(network string) bool { + return strings.HasSuffix(network, "6") +} + +// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0). +func zoneInterface(address string) *net.Interface { + if address == "" { + return nil + } + addr, err := netip.ParseAddrPort(address) + if err != nil { + a, err := netip.ParseAddr(address) + if err != nil { + return nil + } + addr = netip.AddrPortFrom(a, 0) + } + zone := addr.Addr().Zone() + if zone == "" { + return nil + } + if iface, err := net.InterfaceByName(zone); err == nil { + return iface + } + if idx, err := strconv.Atoi(zone); err == nil { + if iface, err := net.InterfaceByIndex(idx); err == nil { + return iface + } + } + return nil +} + +func setIPv4BoundIf(fd uintptr, iface *net.Interface) error { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil { + return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6BoundIf(fd uintptr, iface *net.Interface) error { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil { + return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +// applyBoundIfToSocket binds the socket to the cached physical egress interface +// so scoped route lookup avoids the VPN utun and egresses the underlay directly. +func applyBoundIfToSocket(network, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + iface := boundInterfaceFor(network, address) + if iface == nil { + log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address) + return nil + } + + isV6 := isV6Network(network) + var controlErr error + if err := c.Control(func(fd uintptr) { + if isV6 { + controlErr = setIPv6BoundIf(fd, iface) + } else { + controlErr = setIPv4BoundIf(fd, iface) + } + if controlErr == nil { + log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address) + } + }); err != nil { + return fmt.Errorf("control: %w", err) + } + return controlErr +} diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 7c9fa021a..0e330bdac 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -1,12 +1,10 @@ #!/usr/bin/env bash set -eEuo pipefail -: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} +: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="30"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() -log_file_path="" _log() { # mimic Go logger's output for easier parsing @@ -33,60 +31,29 @@ on_exit() { fi } -wait_for_message() { - local timeout="${1}" message="${2}" - if test "${timeout}" -eq 0; then - info "not waiting for log line ${message@Q} due to zero timeout." - elif test -n "${log_file_path}"; then - info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) - else - info "log file unsupported, sleeping for ${timeout} seconds..." - sleep "${timeout}" - fi -} - -locate_log_file() { - local log_files_string="${1}" - - while read -r log_file; do - case "${log_file}" in - console | syslog) ;; - *) - log_file_path="${log_file}" - return - ;; - esac - done < <(sed 's#,#\n#g' <<<"${log_files_string}") - - warn "log files parsing for ${log_files_string@Q} is not supported by debug bundles" - warn "please consider removing the \$NB_LOG_FILE or setting it to real file, before gathering debug bundles." -} - wait_for_daemon_startup() { local timeout="${1}" - - if test -n "${log_file_path}"; then - if ! wait_for_message "${timeout}" "started daemon server"; then - warn "log line containing 'started daemon server' not found after ${timeout} seconds" - warn "daemon failed to start, exiting..." - exit 1 - fi - else - warn "daemon service startup not discovered, sleeping ${timeout} instead" - sleep "${timeout}" + if [[ "${timeout}" -eq 0 ]]; then + info "not waiting for daemon startup due to zero timeout." + return fi + + local deadline=$((SECONDS + timeout)) + while [[ "${SECONDS}" -lt "${deadline}" ]]; do + if "${NETBIRD_BIN}" status --check live 2>/dev/null; then + return + fi + sleep 1 + done + + warn "daemon did not become responsive after ${timeout} seconds, exiting..." + exit 1 } -login_if_needed() { - local timeout="${1}" - - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then - info "already logged in, skipping 'netbird up'..." - else - info "logging in..." - "${NETBIRD_BIN}" up - fi +connect() { + info "running 'netbird up'..." + "${NETBIRD_BIN}" up + return $? } main() { @@ -95,9 +62,8 @@ main() { service_pids+=("$!") info "registered new service process 'netbird service run', currently running: ${service_pids[@]@Q}" - locate_log_file "${NB_LOG_FILE}" wait_for_daemon_startup "${NB_ENTRYPOINT_SERVICE_TIMEOUT}" - login_if_needed "${NB_ENTRYPOINT_LOGIN_TIMEOUT}" + connect wait "${service_pids[@]}" } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 1d9d7233c..6506307d3 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v6.33.1 // source: daemon.proto package proto @@ -88,6 +88,61 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +type ExposeProtocol int32 + +const ( + ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0 + ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 + ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 + ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 +) + +// Enum value maps for ExposeProtocol. +var ( + ExposeProtocol_name = map[int32]string{ + 0: "EXPOSE_HTTP", + 1: "EXPOSE_HTTPS", + 2: "EXPOSE_TCP", + 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", + } + ExposeProtocol_value = map[string]int32{ + "EXPOSE_HTTP": 0, + "EXPOSE_HTTPS": 1, + "EXPOSE_TCP": 2, + "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, + } +) + +func (x ExposeProtocol) Enum() *ExposeProtocol { + p := new(ExposeProtocol) + *p = x + return p +} + +func (x ExposeProtocol) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor { + return file_daemon_proto_enumTypes[1].Descriptor() +} + +func (ExposeProtocol) Type() protoreflect.EnumType { + return &file_daemon_proto_enumTypes[1] +} + +func (x ExposeProtocol) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use ExposeProtocol.Descriptor instead. +func (ExposeProtocol) EnumDescriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1} +} + // avoid collision with loglevel enum type OSLifecycleRequest_CycleType int32 @@ -122,11 +177,11 @@ func (x OSLifecycleRequest_CycleType) String() string { } func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[1].Descriptor() + return file_daemon_proto_enumTypes[2].Descriptor() } func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[1] + return &file_daemon_proto_enumTypes[2] } func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { @@ -174,11 +229,11 @@ func (x SystemEvent_Severity) String() string { } func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[2].Descriptor() + return file_daemon_proto_enumTypes[3].Descriptor() } func (SystemEvent_Severity) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[2] + return &file_daemon_proto_enumTypes[3] } func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { @@ -229,11 +284,11 @@ func (x SystemEvent_Category) String() string { } func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[3].Descriptor() + return file_daemon_proto_enumTypes[4].Descriptor() } func (SystemEvent_Category) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[3] + return &file_daemon_proto_enumTypes[4] } func (x SystemEvent_Category) Number() protoreflect.EnumNumber { @@ -893,7 +948,6 @@ type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` - AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -942,13 +996,6 @@ func (x *UpRequest) GetUsername() string { return "" } -func (x *UpRequest) GetAutoUpdate() bool { - if x != nil && x.AutoUpdate != nil { - return *x.AutoUpdate - } - return false -} - type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -4932,6 +4979,7 @@ type GetFeaturesResponse struct { state protoimpl.MessageState `protogen:"open.v1"` DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"` DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"` + DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -4980,6 +5028,101 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +func (x *GetFeaturesResponse) GetDisableNetworks() bool { + if x != nil { + return x.DisableNetworks + } + return false +} + +type TriggerUpdateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TriggerUpdateRequest) Reset() { + *x = TriggerUpdateRequest{} + mi := &file_daemon_proto_msgTypes[73] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TriggerUpdateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerUpdateRequest) ProtoMessage() {} + +func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[73] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead. +func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{73} +} + +type TriggerUpdateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TriggerUpdateResponse) Reset() { + *x = TriggerUpdateResponse{} + mi := &file_daemon_proto_msgTypes[74] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TriggerUpdateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerUpdateResponse) ProtoMessage() {} + +func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[74] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead. +func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{74} +} + +func (x *TriggerUpdateResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *TriggerUpdateResponse) GetErrorMsg() string { + if x != nil { + return x.ErrorMsg + } + return "" +} + // GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer type GetPeerSSHHostKeyRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -4991,7 +5134,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5003,7 +5146,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5016,7 +5159,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -5043,7 +5186,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5055,7 +5198,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5068,7 +5211,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -5110,7 +5253,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5122,7 +5265,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5135,7 +5278,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5168,7 +5311,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5180,7 +5323,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5193,7 +5336,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5258,7 +5401,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5270,7 +5413,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[79] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5283,7 +5426,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{77} + return file_daemon_proto_rawDescGZIP(), []int{79} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5315,7 +5458,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5327,7 +5470,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5340,7 +5483,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{78} + return file_daemon_proto_rawDescGZIP(), []int{80} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5373,7 +5516,7 @@ type StartCPUProfileRequest struct { func (x *StartCPUProfileRequest) Reset() { *x = StartCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[81] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5385,7 +5528,7 @@ func (x *StartCPUProfileRequest) String() string { func (*StartCPUProfileRequest) ProtoMessage() {} func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[81] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5398,7 +5541,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{81} } // StartCPUProfileResponse confirms CPU profiling has started @@ -5410,7 +5553,7 @@ type StartCPUProfileResponse struct { func (x *StartCPUProfileResponse) Reset() { *x = StartCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5422,7 +5565,7 @@ func (x *StartCPUProfileResponse) String() string { func (*StartCPUProfileResponse) ProtoMessage() {} func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5435,7 +5578,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{82} } // StopCPUProfileRequest for stopping CPU profiling @@ -5447,7 +5590,7 @@ type StopCPUProfileRequest struct { func (x *StopCPUProfileRequest) Reset() { *x = StopCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5459,7 +5602,7 @@ func (x *StopCPUProfileRequest) String() string { func (*StopCPUProfileRequest) ProtoMessage() {} func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5472,7 +5615,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{81} + return file_daemon_proto_rawDescGZIP(), []int{83} } // StopCPUProfileResponse confirms CPU profiling has stopped @@ -5484,7 +5627,7 @@ type StopCPUProfileResponse struct { func (x *StopCPUProfileResponse) Reset() { *x = StopCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5496,7 +5639,7 @@ func (x *StopCPUProfileResponse) String() string { func (*StopCPUProfileResponse) ProtoMessage() {} func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5509,7 +5652,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{82} + return file_daemon_proto_rawDescGZIP(), []int{84} } type InstallerResultRequest struct { @@ -5520,7 +5663,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[85] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5532,7 +5675,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[85] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5545,7 +5688,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{83} + return file_daemon_proto_rawDescGZIP(), []int{85} } type InstallerResultResponse struct { @@ -5558,7 +5701,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5570,7 +5713,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5583,7 +5726,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{84} + return file_daemon_proto_rawDescGZIP(), []int{86} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5600,6 +5743,240 @@ func (x *InstallerResultResponse) GetErrorMsg() string { return "" } +type ExposeServiceRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"` + Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"` + Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"` + Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"` + UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` + Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` + NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExposeServiceRequest) Reset() { + *x = ExposeServiceRequest{} + mi := &file_daemon_proto_msgTypes[87] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExposeServiceRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExposeServiceRequest) ProtoMessage() {} + +func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[87] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. +func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{87} +} + +func (x *ExposeServiceRequest) GetPort() uint32 { + if x != nil { + return x.Port + } + return 0 +} + +func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol { + if x != nil { + return x.Protocol + } + return ExposeProtocol_EXPOSE_HTTP +} + +func (x *ExposeServiceRequest) GetPin() string { + if x != nil { + return x.Pin + } + return "" +} + +func (x *ExposeServiceRequest) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +func (x *ExposeServiceRequest) GetUserGroups() []string { + if x != nil { + return x.UserGroups + } + return nil +} + +func (x *ExposeServiceRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *ExposeServiceRequest) GetNamePrefix() string { + if x != nil { + return x.NamePrefix + } + return "" +} + +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + +type ExposeServiceEvent struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Event: + // + // *ExposeServiceEvent_Ready + Event isExposeServiceEvent_Event `protobuf_oneof:"event"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExposeServiceEvent) Reset() { + *x = ExposeServiceEvent{} + mi := &file_daemon_proto_msgTypes[88] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExposeServiceEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExposeServiceEvent) ProtoMessage() {} + +func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[88] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead. +func (*ExposeServiceEvent) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{88} +} + +func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event { + if x != nil { + return x.Event + } + return nil +} + +func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady { + if x != nil { + if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok { + return x.Ready + } + } + return nil +} + +type isExposeServiceEvent_Event interface { + isExposeServiceEvent_Event() +} + +type ExposeServiceEvent_Ready struct { + Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"` +} + +func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {} + +type ExposeServiceReady struct { + state protoimpl.MessageState `protogen:"open.v1"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExposeServiceReady) Reset() { + *x = ExposeServiceReady{} + mi := &file_daemon_proto_msgTypes[89] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExposeServiceReady) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExposeServiceReady) ProtoMessage() {} + +func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[89] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead. +func (*ExposeServiceReady) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{89} +} + +func (x *ExposeServiceReady) GetServiceName() string { + if x != nil { + return x.ServiceName + } + return "" +} + +func (x *ExposeServiceReady) GetServiceUrl() string { + if x != nil { + return x.ServiceUrl + } + return "" +} + +func (x *ExposeServiceReady) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *ExposeServiceReady) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5610,7 +5987,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[91] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5622,7 +5999,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[91] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5746,16 +6123,12 @@ const file_daemon_proto_rawDesc = "" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + "\x14WaitSSOLoginResponse\x12\x14\n" + - "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"v\n" + "\tUpRequest\x12%\n" + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" + - "\n" + - "autoUpdate\x18\x03 \x01(\bH\x02R\n" + - "autoUpdate\x88\x01\x01B\x0e\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + - "\t_usernameB\r\n" + - "\v_autoUpdate\"\f\n" + + "\t_usernameJ\x04\b\x03\x10\x04\"\f\n" + "\n" + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + @@ -6107,10 +6480,15 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\x10\n" + "\x0eLogoutResponse\"\x14\n" + - "\x12GetFeaturesRequest\"x\n" + + "\x12GetFeaturesRequest\"\xa3\x01\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" + + "\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" + + "\x14TriggerUpdateRequest\"M\n" + + "\x15TriggerUpdateResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"<\n" + "\x18GetPeerSSHHostKeyRequest\x12 \n" + "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" + "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" + @@ -6149,7 +6527,28 @@ const file_daemon_proto_rawDesc = "" + "\x16InstallerResultRequest\"O\n" + "\x17InstallerResultResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + - "\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\x87\x02\n" + + "\x14ExposeServiceRequest\x12\x12\n" + + "\x04port\x18\x01 \x01(\rR\x04port\x122\n" + + "\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" + + "\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" + + "\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" + + "\vuser_groups\x18\x05 \x03(\tR\n" + + "userGroups\x12\x16\n" + + "\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" + + "\vname_prefix\x18\a \x01(\tR\n" + + "namePrefix\x12\x1f\n" + + "\vlisten_port\x18\b \x01(\rR\n" + + "listenPort\"Q\n" + + "\x12ExposeServiceEvent\x122\n" + + "\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" + + "\x05event\"\x9e\x01\n" + + "\x12ExposeServiceReady\x12!\n" + + "\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" + + "\vservice_url\x18\x02 \x01(\tR\n" + + "serviceUrl\x12\x16\n" + + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6158,7 +6557,16 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\xdd\x14\n" + + "\x05TRACE\x10\a*c\n" + + "\x0eExposeProtocol\x12\x0f\n" + + "\vEXPOSE_HTTP\x10\x00\x12\x10\n" + + "\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" + + "\n" + + "EXPOSE_TCP\x10\x02\x12\x0e\n" + + "\n" + + "EXPOSE_UDP\x10\x03\x12\x0e\n" + + "\n" + + "EXPOSE_TLS\x10\x042\xfc\x15\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6190,14 +6598,16 @@ const file_daemon_proto_rawDesc = "" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + - "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12N\n" + + "\rTriggerUpdate\x12\x1c.daemon.TriggerUpdateRequest\x1a\x1d.daemon.TriggerUpdateResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + - "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" + + "\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -6211,214 +6621,226 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88) +var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel - (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType - (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity - (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category - (*EmptyRequest)(nil), // 4: daemon.EmptyRequest - (*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest - (*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse - (*LoginRequest)(nil), // 7: daemon.LoginRequest - (*LoginResponse)(nil), // 8: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 11: daemon.UpRequest - (*UpResponse)(nil), // 12: daemon.UpResponse - (*StatusRequest)(nil), // 13: daemon.StatusRequest - (*StatusResponse)(nil), // 14: daemon.StatusResponse - (*DownRequest)(nil), // 15: daemon.DownRequest - (*DownResponse)(nil), // 16: daemon.DownResponse - (*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse - (*PeerState)(nil), // 19: daemon.PeerState - (*LocalPeerState)(nil), // 20: daemon.LocalPeerState - (*SignalState)(nil), // 21: daemon.SignalState - (*ManagementState)(nil), // 22: daemon.ManagementState - (*RelayState)(nil), // 23: daemon.RelayState - (*NSGroupState)(nil), // 24: daemon.NSGroupState - (*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo - (*SSHServerState)(nil), // 26: daemon.SSHServerState - (*FullStatus)(nil), // 27: daemon.FullStatus - (*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse - (*IPList)(nil), // 32: daemon.IPList - (*Network)(nil), // 33: daemon.Network - (*PortInfo)(nil), // 34: daemon.PortInfo - (*ForwardingRule)(nil), // 35: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse - (*State)(nil), // 43: daemon.State - (*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 52: daemon.TCPFlags - (*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest - (*TraceStage)(nil), // 54: daemon.TraceStage - (*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest - (*SystemEvent)(nil), // 57: daemon.SystemEvent - (*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse - (*Profile)(nil), // 70: daemon.Profile - (*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 73: daemon.LogoutRequest - (*LogoutResponse)(nil), // 74: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse - (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse - (*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest - (*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse - (*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest - (*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse - (*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse - nil, // 89: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range - nil, // 91: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 92: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp + (ExposeProtocol)(0), // 1: daemon.ExposeProtocol + (OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType + (SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity + (SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category + (*EmptyRequest)(nil), // 5: daemon.EmptyRequest + (*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest + (*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse + (*LoginRequest)(nil), // 8: daemon.LoginRequest + (*LoginResponse)(nil), // 9: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 12: daemon.UpRequest + (*UpResponse)(nil), // 13: daemon.UpResponse + (*StatusRequest)(nil), // 14: daemon.StatusRequest + (*StatusResponse)(nil), // 15: daemon.StatusResponse + (*DownRequest)(nil), // 16: daemon.DownRequest + (*DownResponse)(nil), // 17: daemon.DownResponse + (*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse + (*PeerState)(nil), // 20: daemon.PeerState + (*LocalPeerState)(nil), // 21: daemon.LocalPeerState + (*SignalState)(nil), // 22: daemon.SignalState + (*ManagementState)(nil), // 23: daemon.ManagementState + (*RelayState)(nil), // 24: daemon.RelayState + (*NSGroupState)(nil), // 25: daemon.NSGroupState + (*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 27: daemon.SSHServerState + (*FullStatus)(nil), // 28: daemon.FullStatus + (*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse + (*IPList)(nil), // 33: daemon.IPList + (*Network)(nil), // 34: daemon.Network + (*PortInfo)(nil), // 35: daemon.PortInfo + (*ForwardingRule)(nil), // 36: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse + (*State)(nil), // 44: daemon.State + (*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 53: daemon.TCPFlags + (*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest + (*TraceStage)(nil), // 55: daemon.TraceStage + (*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest + (*SystemEvent)(nil), // 58: daemon.SystemEvent + (*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse + (*Profile)(nil), // 71: daemon.Profile + (*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 74: daemon.LogoutRequest + (*LogoutResponse)(nil), // 75: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse + (*TriggerUpdateRequest)(nil), // 78: daemon.TriggerUpdateRequest + (*TriggerUpdateResponse)(nil), // 79: daemon.TriggerUpdateResponse + (*GetPeerSSHHostKeyRequest)(nil), // 80: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 81: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 82: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 83: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 84: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 85: daemon.WaitJWTTokenResponse + (*StartCPUProfileRequest)(nil), // 86: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 87: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 88: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 89: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 90: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 91: daemon.InstallerResultResponse + (*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest + (*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent + (*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady + nil, // 95: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range + nil, // 97: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 98: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState - 23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState - 24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType + 98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 98, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState + 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState + 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State - 52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 69, // [69:104] is the sub-list for method output_type - 34, // [34:69] is the sub-list for method input_type - 34, // [34:34] is the sub-list for extension type_name - 34, // [34:34] is the sub-list for extension extendee - 0, // [0:34] is the sub-list for field type_name + 44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State + 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol + 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 6, // 70: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 60, // 92: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 70, // 97: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 73, // [73:110] is the sub-list for method output_type + 36, // [36:73] is the sub-list for method input_type + 36, // [36:36] is the sub-list for extension type_name + 36, // [36:36] is the sub-list for extension extendee + 0, // [0:36] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6438,14 +6860,17 @@ func file_daemon_proto_init() { file_daemon_proto_msgTypes[56].OneofWrappers = []any{} file_daemon_proto_msgTypes[58].OneofWrappers = []any{} file_daemon_proto_msgTypes[69].OneofWrappers = []any{} - file_daemon_proto_msgTypes[75].OneofWrappers = []any{} + file_daemon_proto_msgTypes[77].OneofWrappers = []any{} + file_daemon_proto_msgTypes[88].OneofWrappers = []any{ + (*ExposeServiceEvent_Ready)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), - NumEnums: 4, - NumMessages: 88, + NumEnums: 5, + NumMessages: 93, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 68b9a9348..19976660c 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -85,6 +85,10 @@ service DaemonService { rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {} + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + rpc TriggerUpdate(TriggerUpdateRequest) returns (TriggerUpdateResponse) {} + // GetPeerSSHHostKey retrieves SSH host key for a specific peer rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} @@ -103,6 +107,9 @@ service DaemonService { rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} + + // ExposeService exposes a local port via the NetBird reverse proxy + rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {} } @@ -223,7 +230,7 @@ message WaitSSOLoginResponse { message UpRequest { optional string profileName = 1; optional string username = 2; - optional bool autoUpdate = 3; + reserved 3; } message UpResponse {} @@ -720,6 +727,14 @@ message GetFeaturesRequest{} message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; + bool disable_networks = 3; +} + +message TriggerUpdateRequest {} + +message TriggerUpdateResponse { + bool success = 1; + string errorMsg = 2; } // GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer @@ -801,3 +816,35 @@ message InstallerResultResponse { bool success = 1; string errorMsg = 2; } + +enum ExposeProtocol { + EXPOSE_HTTP = 0; + EXPOSE_HTTPS = 1; + EXPOSE_TCP = 2; + EXPOSE_UDP = 3; + EXPOSE_TLS = 4; +} + +message ExposeServiceRequest { + uint32 port = 1; + ExposeProtocol protocol = 2; + string pin = 3; + string password = 4; + repeated string user_groups = 5; + string domain = 6; + string name_prefix = 7; + uint32 listen_port = 8; +} + +message ExposeServiceEvent { + oneof event { + ExposeServiceReady ready = 1; + } +} + +message ExposeServiceReady { + string service_name = 1; + string service_url = 2; + string domain = 3; + bool port_auto_assigned = 4; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index ea9b4df05..e5bd89597 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -64,6 +64,9 @@ type DaemonServiceClient interface { // Logout disconnects from the network and deletes the peer from the management server Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) // GetPeerSSHHostKey retrieves SSH host key for a specific peer GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) // RequestJWTAuth initiates JWT authentication flow for SSH @@ -76,6 +79,8 @@ type DaemonServiceClient interface { StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) + // ExposeService exposes a local port via the NetBird reverse proxy + ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) } type daemonServiceClient struct { @@ -361,6 +366,15 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe return out, nil } +func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) { + out := new(TriggerUpdateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TriggerUpdate", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { out := new(GetPeerSSHHostKeyResponse) err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) @@ -424,6 +438,38 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal return out, nil } +func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) { + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...) + if err != nil { + return nil, err + } + x := &daemonServiceExposeServiceClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type DaemonService_ExposeServiceClient interface { + Recv() (*ExposeServiceEvent, error) + grpc.ClientStream +} + +type daemonServiceExposeServiceClient struct { + grpc.ClientStream +} + +func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) { + m := new(ExposeServiceEvent) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -474,6 +520,9 @@ type DaemonServiceServer interface { // Logout disconnects from the network and deletes the peer from the management server Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) // GetPeerSSHHostKey retrieves SSH host key for a specific peer GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) // RequestJWTAuth initiates JWT authentication flow for SSH @@ -486,6 +535,8 @@ type DaemonServiceServer interface { StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) + // ExposeService exposes a local port via the NetBird reverse proxy + ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error mustEmbedUnimplementedDaemonServiceServer() } @@ -577,6 +628,9 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") } +func (UnimplementedDaemonServiceServer) TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method TriggerUpdate not implemented") +} func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") } @@ -598,6 +652,9 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") } +func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error { + return status.Errorf(codes.Unimplemented, "method ExposeService not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1118,6 +1175,24 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_TriggerUpdate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TriggerUpdateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).TriggerUpdate(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/TriggerUpdate", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).TriggerUpdate(ctx, req.(*TriggerUpdateRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetPeerSSHHostKeyRequest) if err := dec(in); err != nil { @@ -1244,6 +1319,27 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont return interceptor(ctx, in, info, handler) } +func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(ExposeServiceRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream}) +} + +type DaemonService_ExposeServiceServer interface { + Send(*ExposeServiceEvent) error + grpc.ServerStream +} + +type daemonServiceExposeServiceServer struct { + grpc.ServerStream +} + +func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error { + return x.ServerStream.SendMsg(m) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1359,6 +1455,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeatures", Handler: _DaemonService_GetFeatures_Handler, }, + { + MethodName: "TriggerUpdate", + Handler: _DaemonService_TriggerUpdate_Handler, + }, { MethodName: "GetPeerSSHHostKey", Handler: _DaemonService_GetPeerSSHHostKey_Handler, @@ -1394,6 +1494,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ Handler: _DaemonService_SubscribeEvents_Handler, ServerStreams: true, }, + { + StreamName: "ExposeService", + Handler: _DaemonService_ExposeService_Handler, + ServerStreams: true, + }, }, Metadata: "daemon.proto", } diff --git a/client/server/debug.go b/client/server/debug.go index 4c531efba..81708e576 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -26,6 +26,15 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( log.Warnf("failed to get latest sync response: %v", err) } + var clientMetrics debug.MetricsExporter + if s.connectClient != nil { + if engine := s.connectClient.Engine(); engine != nil { + if cm := engine.GetClientMetrics(); cm != nil { + clientMetrics = cm + } + } + } + var cpuProfileData []byte if s.cpuProfileBuf != nil && !s.cpuProfiling { cpuProfileData = s.cpuProfileBuf.Bytes() @@ -54,6 +63,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( LogPath: s.logFile, CPUProfile: cpuProfileData, RefreshStatus: refreshStatus, + ClientMetrics: clientMetrics, }, debug.BundleConfig{ Anonymize: req.GetAnonymize(), diff --git a/client/server/event.go b/client/server/event.go index b5c12a3a6..d93151c96 100644 --- a/client/server/event.go +++ b/client/server/event.go @@ -14,6 +14,7 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo }() log.Debug("client subscribed to events") + s.startUpdateManagerForGUI() for { select { diff --git a/client/server/lifecycle.go b/client/server/lifecycle.go deleted file mode 100644 index 3722c027d..000000000 --- a/client/server/lifecycle.go +++ /dev/null @@ -1,77 +0,0 @@ -package server - -import ( - "context" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/client/proto" -) - -// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. -func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { - switch req.GetType() { - case proto.OSLifecycleRequest_WAKEUP: - return s.handleWakeUp(callerCtx) - case proto.OSLifecycleRequest_SLEEP: - return s.handleSleep(callerCtx) - default: - log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) - } - return &proto.OSLifecycleResponse{}, nil -} - -// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep. -// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails. -func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { - if !s.sleepTriggeredDown.Load() { - log.Info("skipping up because wasn't sleep down") - return &proto.OSLifecycleResponse{}, nil - } - - // avoid other wakeup runs if sleep didn't make the computer sleep - s.sleepTriggeredDown.Store(false) - - log.Info("running up after wake up") - _, err := s.Up(callerCtx, &proto.UpRequest{}) - if err != nil { - log.Errorf("running up failed: %v", err) - return &proto.OSLifecycleResponse{}, err - } - - log.Info("running up command executed successfully") - return &proto.OSLifecycleResponse{}, nil -} - -// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state. -func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { - s.mutex.Lock() - - state := internal.CtxGetState(s.rootCtx) - status, err := state.Status() - if err != nil { - s.mutex.Unlock() - return &proto.OSLifecycleResponse{}, err - } - - if status != internal.StatusConnecting && status != internal.StatusConnected { - log.Infof("skipping setting the agent down because status is %s", status) - s.mutex.Unlock() - return &proto.OSLifecycleResponse{}, nil - } - s.mutex.Unlock() - - log.Info("running down after system started sleeping") - - _, err = s.Down(callerCtx, &proto.DownRequest{}) - if err != nil { - log.Errorf("running down failed: %v", err) - return &proto.OSLifecycleResponse{}, err - } - - s.sleepTriggeredDown.Store(true) - - log.Info("running down executed successfully") - return &proto.OSLifecycleResponse{}, nil -} diff --git a/client/server/lifecycle_test.go b/client/server/lifecycle_test.go deleted file mode 100644 index a604c60af..000000000 --- a/client/server/lifecycle_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package server - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/proto" -) - -func newTestServer() *Server { - ctx := internal.CtxInitState(context.Background()) - return &Server{ - rootCtx: ctx, - statusRecorder: peer.NewRecorder(""), - } -} - -func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) { - s := newTestServer() - - // sleepTriggeredDown is false by default - assert.False(t, s.sleepTriggeredDown.Load()) - - resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_WAKEUP, - }) - - require.NoError(t, err) - require.NotNil(t, resp) - assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false") -} - -func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) { - s := newTestServer() - - state := internal.CtxGetState(s.rootCtx) - state.Set(internal.StatusIdle) - - resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_SLEEP, - }) - - require.NoError(t, err) - require.NotNil(t, resp) - assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle") -} - -func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) { - s := newTestServer() - - state := internal.CtxGetState(s.rootCtx) - state.Set(internal.StatusNeedsLogin) - - resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_SLEEP, - }) - - require.NoError(t, err) - require.NotNil(t, resp) - assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin") -} - -func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) { - s := newTestServer() - - state := internal.CtxGetState(s.rootCtx) - state.Set(internal.StatusConnecting) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s.actCancel = cancel - - resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_SLEEP, - }) - - require.NoError(t, err) - assert.NotNil(t, resp, "handleSleep returns not nil response on success") - assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting") -} - -func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) { - s := newTestServer() - - state := internal.CtxGetState(s.rootCtx) - state.Set(internal.StatusConnected) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s.actCancel = cancel - - resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_SLEEP, - }) - - require.NoError(t, err) - assert.NotNil(t, resp, "handleSleep returns not nil response on success") - assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected") -} - -func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) { - s := newTestServer() - - // Manually set the flag to simulate prior sleep down - s.sleepTriggeredDown.Store(true) - - // WakeUp will try to call Up which fails without proper setup, but flag should reset first - _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_WAKEUP, - }) - - assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt") -} - -func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) { - s := newTestServer() - - // First wakeup without prior sleep - should be no-op - resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_WAKEUP, - }) - require.NoError(t, err) - require.NotNil(t, resp) - assert.False(t, s.sleepTriggeredDown.Load()) - - // Simulate prior sleep - s.sleepTriggeredDown.Store(true) - - // First wakeup after sleep - should reset flag - _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_WAKEUP, - }) - assert.False(t, s.sleepTriggeredDown.Load()) - - // Second wakeup - should be no-op - resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ - Type: proto.OSLifecycleRequest_WAKEUP, - }) - require.NoError(t, err) - require.NotNil(t, resp) - assert.False(t, s.sleepTriggeredDown.Load()) -} - -func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) { - s := newTestServer() - - resp, err := s.handleWakeUp(context.Background()) - - require.NoError(t, err) - require.NotNil(t, resp) -} - -func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) { - s := newTestServer() - s.sleepTriggeredDown.Store(true) - - // Even if Up fails, flag should be reset - _, _ = s.handleWakeUp(context.Background()) - - assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up") -} - -func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) { - tests := []struct { - name string - status internal.StatusType - }{ - {"Idle", internal.StatusIdle}, - {"NeedsLogin", internal.StatusNeedsLogin}, - {"LoginFailed", internal.StatusLoginFailed}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := newTestServer() - state := internal.CtxGetState(s.rootCtx) - state.Set(tt.status) - - resp, err := s.handleSleep(context.Background()) - - require.NoError(t, err) - require.NotNil(t, resp) - assert.False(t, s.sleepTriggeredDown.Load()) - }) - } -} - -func TestHandleSleep_ProceedsForActiveStates(t *testing.T) { - tests := []struct { - name string - status internal.StatusType - }{ - {"Connecting", internal.StatusConnecting}, - {"Connected", internal.StatusConnected}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := newTestServer() - state := internal.CtxGetState(s.rootCtx) - state.Set(tt.status) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s.actCancel = cancel - - resp, err := s.handleSleep(ctx) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, s.sleepTriggeredDown.Load()) - }) - } -} diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c..76c5af40e 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -9,6 +9,8 @@ import ( "strings" "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/route" @@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/server/server.go b/client/server/server.go index 108eab9fe..70e4c342f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -21,13 +21,17 @@ import ( gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/expose" "github.com/netbirdio/netbird/client/internal/profilemanager" + sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler" "github.com/netbirdio/netbird/client/system" mgm "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -49,6 +53,7 @@ const ( errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" + errNetworksDisabled = "network selection is disabled by the administrator" ) var ErrServiceNotUp = errors.New("service is not up") @@ -84,9 +89,11 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + networksDisabled bool - // sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down - sleepTriggeredDown atomic.Bool + sleepHandler *sleephandler.SleepHandler + + updateManager *updater.Manager jwtCache *jwtCache } @@ -99,8 +106,8 @@ type oauthAuthFlow struct { } // New server instance constructor. -func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server { - return &Server{ +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server { + s := &Server{ rootCtx: ctx, logFile: logFile, persistSyncResponse: true, @@ -108,8 +115,13 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + networksDisabled: networksDisabled, jwtCache: newJWTCache(), } + agent := &serverAgent{s} + s.sleepHandler = sleephandler.New(agent) + + return s } func (s *Server) Start() error { @@ -130,6 +142,12 @@ func (s *Server) Start() error { log.Warnf(errRestoreResidualState, err) } + if s.updateManager == nil { + stateMgr := statemanager.New(s.profileManager.GetStatePath()) + s.updateManager = updater.NewManager(s.statusRecorder, stateMgr) + s.updateManager.CheckUpdateSuccess(s.rootCtx) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already successfully established connection. @@ -187,14 +205,14 @@ func (s *Server) Start() error { s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -202,7 +220,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { log.Debugf("run client connection exited with error: %v", err) } log.Tracef("client connection exited") @@ -231,8 +249,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan) - doInitialAutoUpdate = false + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) return err @@ -636,8 +653,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR return s.waitForUp(callerCtx) } - defer s.mutex.Unlock() - if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } @@ -649,10 +664,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR // not in the progress or already successfully established connection. status, err := state.Status() if err != nil { + s.mutex.Unlock() return nil, err } if status != internal.StatusIdle { + s.mutex.Unlock() return nil, fmt.Errorf("up already in progress: current status %s", status) } @@ -669,17 +686,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.actCancel = cancel if s.config == nil { + s.mutex.Unlock() return nil, fmt.Errorf("config is not defined, please call login command first") } activeProf, err := s.profileManager.GetActiveProfileState() if err != nil { + s.mutex.Unlock() log.Errorf("failed to get active profile state: %v", err) return nil, fmt.Errorf("failed to get active profile state: %w", err) } if msg != nil && msg.ProfileName != nil { if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil { + s.mutex.Unlock() log.Errorf("failed to switch profile: %v", err) return nil, fmt.Errorf("failed to switch profile: %w", err) } @@ -687,6 +707,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR activeProf, err = s.profileManager.GetActiveProfileState() if err != nil { + s.mutex.Unlock() log.Errorf("failed to get active profile state: %v", err) return nil, fmt.Errorf("failed to get active profile state: %w", err) } @@ -695,6 +716,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR config, _, err := s.getConfig(activeProf) if err != nil { + s.mutex.Unlock() log.Errorf("failed to get active profile config: %v", err) return nil, fmt.Errorf("failed to get active profile config: %w", err) } @@ -707,12 +729,9 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - var doAutoUpdate bool - if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate { - doAutoUpdate = true - } - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + s.mutex.Unlock() return s.waitForUp(callerCtx) } @@ -838,14 +857,26 @@ func (s *Server) cleanupConnection() error { if s.actCancel == nil { return ErrServiceNotUp } + + // Capture the engine reference before cancelling the context. + // After actCancel(), the connectWithRetryRuns goroutine wakes up + // and sets connectClient.engine = nil, causing connectClient.Stop() + // to skip the engine shutdown entirely. + var engine *internal.Engine + if s.connectClient != nil { + engine = s.connectClient.Engine() + } + s.actCancel() if s.connectClient == nil { return nil } - if err := s.connectClient.Stop(); err != nil { - return err + if engine != nil { + if err := engine.Stop(); err != nil { + return err + } } s.connectClient = nil @@ -1312,6 +1343,65 @@ func (s *Server) WaitJWTToken( }, nil } +// ExposeService exposes a local port via the NetBird reverse proxy. +func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error { + s.mutex.Lock() + if !s.clientRunning { + s.mutex.Unlock() + return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first") + } + connectClient := s.connectClient + s.mutex.Unlock() + + if connectClient == nil { + return gstatus.Errorf(codes.FailedPrecondition, "client not initialized") + } + + engine := connectClient.Engine() + if engine == nil { + return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized") + } + + if engine.IsBlockInbound() { + return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first") + } + + mgr := engine.GetExposeManager() + if mgr == nil { + return gstatus.Errorf(codes.Internal, "expose manager not available") + } + + ctx := srv.Context() + + exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second) + defer exposeCancel() + + mgmReq := expose.NewRequest(req) + result, err := mgr.Expose(exposeCtx, *mgmReq) + if err != nil { + return err + } + + if err := srv.Send(&proto.ExposeServiceEvent{ + Event: &proto.ExposeServiceEvent_Ready{ + Ready: &proto.ExposeServiceReady{ + ServiceName: result.ServiceName, + ServiceUrl: result.ServiceURL, + Domain: result.Domain, + PortAutoAssigned: result.PortAutoAssigned, + }, + }, + }); err != nil { + return err + } + + err = mgr.KeepAlive(ctx, result.Domain) + if err != nil { + return err + } + return nil +} + func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { return false @@ -1541,16 +1631,23 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) features := &proto.GetFeaturesResponse{ DisableProfiles: s.checkProfilesDisabled(), DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + DisableNetworks: s.networksDisabled, } return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - if err := s.connectClient.Run(runningChan, s.logFile); err != nil { + client := internal.NewConnectClient(ctx, config, statusRecorder) + client.SetUpdateManager(s.updateManager) + client.SetSyncResponsePersistence(s.persistSyncResponse) + + s.mutex.Lock() + s.connectClient = client + s.mutex.Unlock() + + if err := client.Run(runningChan, s.logFile); err != nil { return err } return nil @@ -1574,6 +1671,14 @@ func (s *Server) checkUpdateSettingsDisabled() bool { return false } +func (s *Server) startUpdateManagerForGUI() { + if s.updateManager == nil { + return + } + s.updateManager.Start(s.rootCtx) + s.updateManager.NotifyUI() +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() diff --git a/client/server/server_connect_test.go b/client/server/server_connect_test.go new file mode 100644 index 000000000..faea7da39 --- /dev/null +++ b/client/server/server_connect_test.go @@ -0,0 +1,187 @@ +package server + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/proto" +) + +func newTestServer() *Server { + return &Server{ + rootCtx: context.Background(), + statusRecorder: peer.NewRecorder(""), + } +} + +func newDummyConnectClient(ctx context.Context) *internal.ConnectClient { + return internal.NewConnectClient(ctx, nil, nil) +} + +// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient +// under mutex protection so concurrent readers see a consistent value. +func TestConnectSetsClientWithMutex(t *testing.T) { + s := newTestServer() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Manually simulate what connect() does (without calling Run which panics without full setup) + client := newDummyConnectClient(ctx) + + s.mutex.Lock() + s.connectClient = client + s.mutex.Unlock() + + // Verify the assignment is visible under mutex + s.mutex.Lock() + assert.Equal(t, client, s.connectClient, "connectClient should be set") + s.mutex.Unlock() +} + +// TestConcurrentConnectClientAccess validates that concurrent reads of +// s.connectClient under mutex don't race with a write. +func TestConcurrentConnectClientAccess(t *testing.T) { + s := newTestServer() + ctx := context.Background() + client := newDummyConnectClient(ctx) + + var wg sync.WaitGroup + nilCount := 0 + setCount := 0 + var mu sync.Mutex + + // Start readers + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s.mutex.Lock() + c := s.connectClient + s.mutex.Unlock() + + mu.Lock() + defer mu.Unlock() + if c == nil { + nilCount++ + } else { + setCount++ + } + }() + } + + // Simulate connect() writing under mutex + time.Sleep(5 * time.Millisecond) + s.mutex.Lock() + s.connectClient = client + s.mutex.Unlock() + + wg.Wait() + + assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic") +} + +// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection +// properly nils out connectClient. +func TestCleanupConnection_ClearsConnectClient(t *testing.T) { + s := newTestServer() + _, cancel := context.WithCancel(context.Background()) + s.actCancel = cancel + + s.connectClient = newDummyConnectClient(context.Background()) + s.clientRunning = true + + err := s.cleanupConnection() + require.NoError(t, err) + + assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup") +} + +// TestCleanState_NilConnectClient validates that CleanState doesn't panic +// when connectClient is nil. +func TestCleanState_NilConnectClient(t *testing.T) { + s := newTestServer() + s.connectClient = nil + s.profileManager = nil // will cause error if it tries to proceed past the nil check + + // Should not panic — the nil check should prevent calling Status() on nil + assert.NotPanics(t, func() { + _, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true}) + }) +} + +// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic +// when connectClient is nil. +func TestDeleteState_NilConnectClient(t *testing.T) { + s := newTestServer() + s.connectClient = nil + s.profileManager = nil + + assert.NotPanics(t, func() { + _, _ = s.DeleteState(context.Background(), &proto.DeleteStateRequest{All: true}) + }) +} + +// TestDownThenUp_StaleRunningChan documents the known state issue where +// clientRunningChan from a previous connection is already closed, causing +// waitForUp() to return immediately on reconnect. +func TestDownThenUp_StaleRunningChan(t *testing.T) { + s := newTestServer() + + // Simulate state after a successful connection + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + close(s.clientRunningChan) // closed when engine started + s.clientGiveUpChan = make(chan struct{}) + s.connectClient = newDummyConnectClient(context.Background()) + + _, cancel := context.WithCancel(context.Background()) + s.actCancel = cancel + + // Simulate Down(): cleanupConnection sets connectClient = nil + s.mutex.Lock() + err := s.cleanupConnection() + s.mutex.Unlock() + require.NoError(t, err) + + // After cleanup: connectClient is nil, clientRunning still true + // (goroutine hasn't exited yet) + s.mutex.Lock() + assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup") + assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits") + s.mutex.Unlock() + + // waitForUp() returns immediately due to stale closed clientRunningChan + ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer ctxCancel() + + waitDone := make(chan error, 1) + go func() { + _, err := s.waitForUp(ctx) + waitDone <- err + }() + + select { + case err := <-waitDone: + assert.NoError(t, err, "waitForUp returns success on stale channel") + // But connectClient is still nil — this is the stale state issue + s.mutex.Lock() + assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success") + s.mutex.Unlock() + case <-time.After(1 * time.Second): + t.Fatal("waitForUp should have returned immediately due to stale closed channel") + } +} + +// TestConnectClient_EngineNilOnFreshClient validates that a newly created +// ConnectClient has nil Engine (before Run is called). +func TestConnectClient_EngineNilOnFreshClient(t *testing.T) { + client := newDummyConnectClient(context.Background()) + assert.Nil(t, client.Engine(), "engine should be nil on fresh ConnectClient") +} diff --git a/client/server/server_test.go b/client/server/server_test.go index 82079c531..772997575 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -36,6 +36,7 @@ import ( daemonProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false) + s := New(ctx, "debug", "", false, false, false) s.config = config @@ -113,7 +114,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) err = s.Start() require.NoError(t, err) @@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) err = s.Start() require.NoError(t, err) @@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 8e360175d..7f6847c43 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) rosenpassEnabled := true rosenpassPermissive := true diff --git a/client/server/sleep.go b/client/server/sleep.go new file mode 100644 index 000000000..7a83c75a6 --- /dev/null +++ b/client/server/sleep.go @@ -0,0 +1,46 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" +) + +// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces +type serverAgent struct { + s *Server +} + +func (a *serverAgent) Up(ctx context.Context) error { + _, err := a.s.Up(ctx, &proto.UpRequest{}) + return err +} + +func (a *serverAgent) Down(ctx context.Context) error { + _, err := a.s.Down(ctx, &proto.DownRequest{}) + return err +} + +func (a *serverAgent) Status() (internal.StatusType, error) { + return internal.CtxGetState(a.s.rootCtx).Status() +} + +// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. +func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { + switch req.GetType() { + case proto.OSLifecycleRequest_WAKEUP: + if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil { + return &proto.OSLifecycleResponse{}, err + } + case proto.OSLifecycleRequest_SLEEP: + if err := s.sleepHandler.HandleSleep(callerCtx); err != nil { + return &proto.OSLifecycleResponse{}, err + } + default: + log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) + } + return &proto.OSLifecycleResponse{}, nil +} diff --git a/client/server/state.go b/client/server/state.go index 1cf85cd37..f2d823465 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -12,7 +12,6 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -39,7 +38,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro // CleanState handles cleaning of states (performing cleanup operations) func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) { - if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { + if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) { return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") } @@ -82,7 +81,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) ( // DeleteState handles deletion of states without cleanup func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) { - if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { + if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) { return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") } @@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error { } // clean up any remaining routes independently of the state file - if !nbnet.AdvancedRouting() { - if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) - } + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/server/state_generic.go b/client/server/state_generic.go index 980ba0cda..86475ca42 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 019477d8e..b193d4dfa 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/server/triggerupdate.go b/client/server/triggerupdate.go new file mode 100644 index 000000000..ffcb527e7 --- /dev/null +++ b/client/server/triggerupdate.go @@ -0,0 +1,24 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +// TriggerUpdate initiates installation of the pending enforced version. +// It is called when the user clicks the install button in the UI (Mode 2 / enforced update). +func (s *Server) TriggerUpdate(ctx context.Context, _ *proto.TriggerUpdateRequest) (*proto.TriggerUpdateResponse, error) { + if s.updateManager == nil { + return &proto.TriggerUpdateResponse{Success: false, ErrorMsg: "update manager not available"}, nil + } + + if err := s.updateManager.Install(ctx); err != nil { + log.Warnf("TriggerUpdate failed: %v", err) + return &proto.TriggerUpdateResponse{Success: false, ErrorMsg: err.Error()}, nil + } + + return &proto.TriggerUpdateResponse{Success: true}, nil +} diff --git a/client/server/updateresult.go b/client/server/updateresult.go index 8e00d5062..8d1ef0e5f 100644 --- a/client/server/updateresult.go +++ b/client/server/updateresult.go @@ -5,7 +5,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/proto" ) diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 342da7303..7f72a72cf 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "github.com/netbirdio/netbird/client/internal/daemonaddr" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" @@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string { if runtime.GOOS == "windows" { return DefaultDaemonAddrWindows } - return DefaultDaemonAddr + return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr) } // DialOptions contains options for SSH connections diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index cc47fd2d2..6e584b2c3 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { return "", fmt.Errorf("get NetBird executable path: %w", err) } - hostLine := strings.Join(deduplicatedPatterns, " ") - config := fmt.Sprintf("Host %s\n", hostLine) - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) - config += " PreferredAuthentications password,publickey,keyboard-interactive\n" - config += " PasswordAuthentication yes\n" - config += " PubkeyAuthentication yes\n" - config += " BatchMode no\n" - config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) - config += " StrictHostKeyChecking no\n" + hostList := strings.Join(deduplicatedPatterns, ",") + config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath) + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) + config += " StrictHostKeyChecking no\n" if runtime.GOOS == "windows" { - config += " UserKnownHostsFile NUL\n" + config += " UserKnownHostsFile NUL\n" } else { - config += " UserKnownHostsFile /dev/null\n" + config += " UserKnownHostsFile /dev/null\n" } - config += " CheckHostIP no\n" - config += " LogLevel ERROR\n\n" + config += " CheckHostIP no\n" + config += " LogLevel ERROR\n\n" return config, nil } diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index dc3ad95b3..e7380c7f2 100644 --- a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) { assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") } +func TestManager_MatchHostFormat(t *testing.T) { + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + peers := []PeerSSHInfo{ + {Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"}, + {Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"}, + } + + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + // Must use "Match host" with comma-separated patterns, not a bare "Host" directive. + // A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block + // ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped. + assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive") + assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"", + "should use Match host with comma-separated patterns") +} + func TestManager_ForcedSSHConfig(t *testing.T) { // Set force environment variable t.Setenv(EnvForceSSHConfig, "true") diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 8897b9c7e..59007f75c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error { func (p *SSHProxy) handleSSHSession(session ssh.Session) { ptyReq, winCh, isPty := session.Pty() - hasCommand := len(session.Command()) > 0 + hasCommand := session.RawCommand() != "" sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User()) if err != nil { @@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) { } if hasCommand { - if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { + if err := serverSession.Run(session.RawCommand()); err != nil { log.Debugf("run command: %v", err) p.handleProxyExitCode(session, err) } diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index dba2e88da..b33d5f8f4 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) { cancel() } +// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting +// when forwarding commands to the backend. This is critical for tools like +// Ansible that send commands such as: +// +// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0' +// +// The single quotes must be preserved so the backend shell receives the +// subshell expression as a single argument to -c. +func TestSSHProxy_CommandQuoting(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + sshClient, cleanup := setupProxySSHClient(t) + defer cleanup() + + // These commands simulate what the SSH protocol delivers as exec payloads. + // When a user types: ssh host '/bin/sh -c "( echo hello )"' + // the local shell strips the outer single quotes, and the SSH exec request + // contains the raw string: /bin/sh -c "( echo hello )" + // + // The proxy must forward this string verbatim. Using session.Command() + // (shlex.Split + strings.Join) strips the inner double quotes, breaking + // the command on the backend. + tests := []struct { + name string + command string + expect string + }{ + { + name: "subshell_in_double_quotes", + command: `/bin/sh -c "( echo from-subshell ) && echo outer"`, + expect: "from-subshell\nouter\n", + }, + { + name: "printf_with_special_chars", + command: `/bin/sh -c "printf '%s\n' 'hello world'"`, + expect: "hello world\n", + }, + { + name: "nested_command_substitution", + command: `/bin/sh -c "echo $(echo nested)"`, + expect: "nested\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + session, err := sshClient.NewSession() + require.NoError(t, err) + defer func() { _ = session.Close() }() + + var stderrBuf bytes.Buffer + session.Stderr = &stderrBuf + + outputCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + output, err := session.Output(tc.command) + outputCh <- output + errCh <- err + }() + + select { + case output := <-outputCh: + err := <-errCh + if stderrBuf.Len() > 0 { + t.Logf("stderr: %s", stderrBuf.String()) + } + require.NoError(t, err, "command should succeed: %s", tc.command) + assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command) + case <-time.After(5 * time.Second): + t.Fatalf("command timed out: %s", tc.command) + } + }) + } +} + +// setupProxySSHClient creates a full proxy test environment and returns +// an SSH client connected through the proxy to a backend NetBird SSH server. +func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) { + t.Helper() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + hostPubKey, err := nbssh.GeneratePublicKey(hostKey) + require.NoError(t, err) + + serverConfig := &server.Config{ + HostKeyPEM: hostKey, + JWT: &server.JWTConfig{ + Issuer: issuer, + Audiences: []string{audience}, + KeysLocation: jwksURL, + }, + } + sshServer := server.New(serverConfig) + sshServer.SetAllowRootLogin(true) + + testUsername := testutil.GetTestUsername(t) + testJWTUser := "test-username" + testUserHash, err := sshuserhash.HashUserID(testJWTUser) + require.NoError(t, err) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + testUsername: {0}, + }, + } + sshServer.UpdateSSHAuth(authConfig) + + sshServerAddr := server.StartTestServer(t, sshServer) + + mockDaemon := startMockDaemon(t) + + host, portStr, err := net.SplitHostPort(sshServerAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + mockDaemon.setHostKey(host, hostPubKey) + + validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser) + mockDaemon.setJWTToken(validToken) + + proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil) + require.NoError(t, err) + + origStdin := os.Stdin + origStdout := os.Stdout + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + stdoutReader, stdoutWriter, err := os.Pipe() + require.NoError(t, err) + + os.Stdin = stdinReader + os.Stdout = stdoutWriter + + clientConn, proxyConn := net.Pipe() + + go func() { _, _ = io.Copy(stdinWriter, proxyConn) }() + go func() { _, _ = io.Copy(proxyConn, stdoutReader) }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + go func() { + _ = proxyInstance.Connect(ctx) + }() + + sshConfig := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig) + require.NoError(t, err) + + client := cryptossh.NewClient(sshClientConn, chans, reqs) + + cleanupFn := func() { + _ = client.Close() + _ = clientConn.Close() + cancel() + os.Stdin = origStdin + os.Stdout = origStdout + _ = sshServer.Stop() + mockDaemon.stop() + jwksServer.Close() + } + + return client, cleanupFn +} + type mockDaemonServer struct { proto.UnimplementedDaemonServiceServer hostKeys map[string][]byte diff --git a/client/ssh/server/getent_cgo_unix.go b/client/ssh/server/getent_cgo_unix.go new file mode 100644 index 000000000..4afbfc627 --- /dev/null +++ b/client/ssh/server/getent_cgo_unix.go @@ -0,0 +1,24 @@ +//go:build cgo && !osusergo && !windows + +package server + +import "os/user" + +// lookupWithGetent with CGO delegates directly to os/user.Lookup. +// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through +// the NSS stack natively. If it fails, the user truly doesn't exist and +// getent would also fail. +func lookupWithGetent(username string) (*user.User, error) { + return user.Lookup(username) +} + +// currentUserWithGetent with CGO delegates directly to os/user.Current. +func currentUserWithGetent() (*user.User, error) { + return user.Current() +} + +// groupIdsWithFallback with CGO delegates directly to user.GroupIds. +// libc's getgrouplist handles NSS groups natively. +func groupIdsWithFallback(u *user.User) ([]string, error) { + return u.GroupIds() +} diff --git a/client/ssh/server/getent_nocgo_unix.go b/client/ssh/server/getent_nocgo_unix.go new file mode 100644 index 000000000..314daae4c --- /dev/null +++ b/client/ssh/server/getent_nocgo_unix.go @@ -0,0 +1,74 @@ +//go:build (!cgo || osusergo) && !windows + +package server + +import ( + "os" + "os/user" + "strconv" + + log "github.com/sirupsen/logrus" +) + +// lookupWithGetent looks up a user by name, falling back to getent if os/user fails. +// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users. +// getent goes through the host's NSS stack. +func lookupWithGetent(username string) (*user.User, error) { + u, err := user.Lookup(username) + if err == nil { + return u, nil + } + + stdErr := err + log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err) + + u, _, getentErr := runGetent(username) + if getentErr != nil { + log.Debugf("getent fallback for %q also failed: %v", username, getentErr) + return nil, stdErr + } + + return u, nil +} + +// currentUserWithGetent gets the current user, falling back to getent if os/user fails. +func currentUserWithGetent() (*user.User, error) { + u, err := user.Current() + if err == nil { + return u, nil + } + + stdErr := err + uid := strconv.Itoa(os.Getuid()) + log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err) + + u, _, getentErr := runGetent(uid) + if getentErr != nil { + return nil, stdErr + } + + return u, nil +} + +// groupIdsWithFallback gets group IDs for a user via the id command first, +// falling back to user.GroupIds(). +// NOTE: unlike lookupWithGetent/currentUserWithGetent which try stdlib first, +// this intentionally tries `id -G` first because without CGO, user.GroupIds() +// only reads /etc/group and silently returns incomplete results for NSS users +// (no error, just missing groups). The id command goes through NSS and returns +// the full set. +func groupIdsWithFallback(u *user.User) ([]string, error) { + ids, err := runIdGroups(u.Username) + if err == nil { + return ids, nil + } + + log.Debugf("id -G %q failed, falling back to user.GroupIds(): %v", u.Username, err) + + ids, stdErr := u.GroupIds() + if stdErr != nil { + return nil, stdErr + } + + return ids, nil +} diff --git a/client/ssh/server/getent_test.go b/client/ssh/server/getent_test.go new file mode 100644 index 000000000..5eac2fdbe --- /dev/null +++ b/client/ssh/server/getent_test.go @@ -0,0 +1,172 @@ +package server + +import ( + "os/user" + "runtime" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLookupWithGetent_CurrentUser(t *testing.T) { + // The current user should always be resolvable on any platform + current, err := user.Current() + require.NoError(t, err) + + u, err := lookupWithGetent(current.Username) + require.NoError(t, err) + assert.Equal(t, current.Username, u.Username) + assert.Equal(t, current.Uid, u.Uid) + assert.Equal(t, current.Gid, u.Gid) +} + +func TestLookupWithGetent_NonexistentUser(t *testing.T) { + _, err := lookupWithGetent("nonexistent_user_xyzzy_12345") + require.Error(t, err, "should fail for nonexistent user") +} + +func TestCurrentUserWithGetent(t *testing.T) { + stdUser, err := user.Current() + require.NoError(t, err) + + u, err := currentUserWithGetent() + require.NoError(t, err) + assert.Equal(t, stdUser.Uid, u.Uid) + assert.Equal(t, stdUser.Username, u.Username) +} + +func TestGroupIdsWithFallback_CurrentUser(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + groups, err := groupIdsWithFallback(current) + require.NoError(t, err) + require.NotEmpty(t, groups, "current user should have at least one group") + + if runtime.GOOS != "windows" { + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be a valid uint32", gid) + } + } +} + +func TestGetShellFromGetent_CurrentUser(t *testing.T) { + if runtime.GOOS == "windows" { + // Windows stub always returns empty, which is correct + shell := getShellFromGetent("1000") + assert.Empty(t, shell, "Windows stub should return empty") + return + } + + current, err := user.Current() + require.NoError(t, err) + + // getent may not be available on all systems (e.g., macOS without Homebrew getent) + shell := getShellFromGetent(current.Uid) + if shell == "" { + t.Log("getShellFromGetent returned empty, getent may not be available") + return + } + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) +} + +func TestLookupWithGetent_RootUser(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("no root user on Windows") + } + + u, err := lookupWithGetent("root") + if err != nil { + t.Skip("root user not available on this system") + } + assert.Equal(t, "0", u.Uid, "root should have UID 0") +} + +// TestIntegration_FullLookupChain exercises the complete user lookup chain +// against the real system, testing that all wrappers (lookupWithGetent, +// currentUserWithGetent, groupIdsWithFallback, getShellFromGetent) produce +// consistent and correct results when composed together. +func TestIntegration_FullLookupChain(t *testing.T) { + // Step 1: currentUserWithGetent must resolve the running user. + current, err := currentUserWithGetent() + require.NoError(t, err, "currentUserWithGetent must resolve the running user") + require.NotEmpty(t, current.Uid) + require.NotEmpty(t, current.Username) + + // Step 2: lookupWithGetent by the same username must return matching identity. + byName, err := lookupWithGetent(current.Username) + require.NoError(t, err) + assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID") + assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID") + assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home") + + // Step 3: groupIdsWithFallback must return at least the primary GID. + groups, err := groupIdsWithFallback(current) + require.NoError(t, err) + require.NotEmpty(t, groups, "user must have at least one group") + + foundPrimary := false + for _, gid := range groups { + if runtime.GOOS != "windows" { + _, err := strconv.ParseUint(gid, 10, 32) + require.NoError(t, err, "group ID %q must be a valid uint32", gid) + } + if gid == current.Gid { + foundPrimary = true + } + } + assert.True(t, foundPrimary, "primary GID %s should appear in supplementary groups", current.Gid) + + // Step 4: getShellFromGetent should either return a valid shell path or empty + // (empty is OK when getent is not available, e.g. macOS without Homebrew getent). + if runtime.GOOS != "windows" { + shell := getShellFromGetent(current.Uid) + if shell != "" { + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) + } + } +} + +// TestIntegration_LookupAndGroupsConsistency verifies that a user resolved via +// lookupWithGetent can have their groups resolved via groupIdsWithFallback, +// testing the handoff between the two functions as used by the SSH server. +func TestIntegration_LookupAndGroupsConsistency(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + // Simulate the SSH server flow: lookup user, then get their groups. + resolved, err := lookupWithGetent(current.Username) + require.NoError(t, err) + + groups, err := groupIdsWithFallback(resolved) + require.NoError(t, err) + require.NotEmpty(t, groups, "resolved user must have groups") + + // On Unix, all returned GIDs must be valid numeric values. + // On Windows, group IDs are SIDs (e.g., "S-1-5-32-544"). + if runtime.GOOS != "windows" { + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be numeric", gid) + } + } +} + +// TestIntegration_ShellLookupChain tests the full shell resolution chain +// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix. +func TestIntegration_ShellLookupChain(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix shell lookup not applicable on Windows") + } + + current, err := user.Current() + require.NoError(t, err) + + // getUserShell is the top-level function used by the SSH server. + shell := getUserShell(current.Uid) + require.NotEmpty(t, shell, "getUserShell must always return a shell") + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) +} diff --git a/client/ssh/server/getent_unix.go b/client/ssh/server/getent_unix.go new file mode 100644 index 000000000..18edb2fdf --- /dev/null +++ b/client/ssh/server/getent_unix.go @@ -0,0 +1,122 @@ +//go:build !windows + +package server + +import ( + "context" + "fmt" + "os/exec" + "os/user" + "runtime" + "strings" + "time" +) + +const getentTimeout = 5 * time.Second + +// getShellFromGetent gets a user's login shell via getent by UID. +// This is needed even with CGO because getShellFromPasswd reads /etc/passwd +// directly and won't find NSS-provided users there. +func getShellFromGetent(userID string) string { + _, shell, err := runGetent(userID) + if err != nil { + return "" + } + return shell +} + +// runGetent executes `getent passwd ` and returns the user and login shell. +func runGetent(query string) (*user.User, string, error) { + if !validateGetentInput(query) { + return nil, "", fmt.Errorf("invalid getent input: %q", query) + } + + ctx, cancel := context.WithTimeout(context.Background(), getentTimeout) + defer cancel() + + out, err := exec.CommandContext(ctx, "getent", "passwd", query).Output() + if err != nil { + return nil, "", fmt.Errorf("getent passwd %s: %w", query, err) + } + + return parseGetentPasswd(string(out)) +} + +// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell" +func parseGetentPasswd(output string) (*user.User, string, error) { + fields := strings.SplitN(strings.TrimSpace(output), ":", 8) + if len(fields) < 6 { + return nil, "", fmt.Errorf("unexpected getent output (need 6+ fields): %q", output) + } + + if fields[0] == "" || fields[2] == "" || fields[3] == "" { + return nil, "", fmt.Errorf("missing required fields in getent output: %q", output) + } + + var shell string + if len(fields) >= 7 { + shell = fields[6] + } + + return &user.User{ + Username: fields[0], + Uid: fields[2], + Gid: fields[3], + Name: fields[4], + HomeDir: fields[5], + }, shell, nil +} + +// validateGetentInput checks that the input is safe to pass to getent or id. +// Allows POSIX usernames, numeric UIDs, and common NSS extensions +// (@ for Kerberos, $ for Samba, + for NIS compat). +func validateGetentInput(input string) bool { + maxLen := 32 + if runtime.GOOS == "linux" { + maxLen = 256 + } + + if len(input) == 0 || len(input) > maxLen { + return false + } + + for _, r := range input { + if isAllowedGetentChar(r) { + continue + } + return false + } + return true +} + +func isAllowedGetentChar(r rune) bool { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' { + return true + } + switch r { + case '.', '_', '-', '@', '+', '$': + return true + } + return false +} + +// runIdGroups runs `id -G ` and returns the space-separated group IDs. +func runIdGroups(username string) ([]string, error) { + if !validateGetentInput(username) { + return nil, fmt.Errorf("invalid username for id command: %q", username) + } + + ctx, cancel := context.WithTimeout(context.Background(), getentTimeout) + defer cancel() + + out, err := exec.CommandContext(ctx, "id", "-G", username).Output() + if err != nil { + return nil, fmt.Errorf("id -G %s: %w", username, err) + } + + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" { + return nil, fmt.Errorf("id -G %s: empty output", username) + } + return strings.Fields(trimmed), nil +} diff --git a/client/ssh/server/getent_unix_test.go b/client/ssh/server/getent_unix_test.go new file mode 100644 index 000000000..e44563b79 --- /dev/null +++ b/client/ssh/server/getent_unix_test.go @@ -0,0 +1,410 @@ +//go:build !windows + +package server + +import ( + "os/exec" + "os/user" + "runtime" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseGetentPasswd(t *testing.T) { + tests := []struct { + name string + input string + wantUser *user.User + wantShell string + wantErr bool + errContains string + }{ + { + name: "standard entry", + input: "alice:x:1001:1001:Alice Smith:/home/alice:/bin/bash\n", + wantUser: &user.User{ + Username: "alice", + Uid: "1001", + Gid: "1001", + Name: "Alice Smith", + HomeDir: "/home/alice", + }, + wantShell: "/bin/bash", + }, + { + name: "root entry", + input: "root:x:0:0:root:/root:/bin/bash", + wantUser: &user.User{ + Username: "root", + Uid: "0", + Gid: "0", + Name: "root", + HomeDir: "/root", + }, + wantShell: "/bin/bash", + }, + { + name: "empty gecos field", + input: "svc:x:999:999::/var/lib/svc:/usr/sbin/nologin", + wantUser: &user.User{ + Username: "svc", + Uid: "999", + Gid: "999", + Name: "", + HomeDir: "/var/lib/svc", + }, + wantShell: "/usr/sbin/nologin", + }, + { + name: "gecos with commas", + input: "john:x:1002:1002:John Doe,Room 101,555-1234,555-4321:/home/john:/bin/zsh", + wantUser: &user.User{ + Username: "john", + Uid: "1002", + Gid: "1002", + Name: "John Doe,Room 101,555-1234,555-4321", + HomeDir: "/home/john", + }, + wantShell: "/bin/zsh", + }, + { + name: "remote user with large UID", + input: "remoteuser:*:50001:50001:Remote User:/home/remoteuser:/bin/bash\n", + wantUser: &user.User{ + Username: "remoteuser", + Uid: "50001", + Gid: "50001", + Name: "Remote User", + HomeDir: "/home/remoteuser", + }, + wantShell: "/bin/bash", + }, + { + name: "no shell field (only 6 fields)", + input: "minimal:x:1000:1000::/home/minimal", + wantUser: &user.User{ + Username: "minimal", + Uid: "1000", + Gid: "1000", + Name: "", + HomeDir: "/home/minimal", + }, + wantShell: "", + }, + { + name: "too few fields", + input: "bad:x:1000", + wantErr: true, + errContains: "need 6+ fields", + }, + { + name: "empty username", + input: ":x:1000:1000::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty UID", + input: "test:x::1000::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty GID", + input: "test:x:1000:::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty input", + input: "", + wantErr: true, + errContains: "need 6+ fields", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, shell, err := parseGetentPasswd(tt.input) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantUser.Username, u.Username, "username") + assert.Equal(t, tt.wantUser.Uid, u.Uid, "UID") + assert.Equal(t, tt.wantUser.Gid, u.Gid, "GID") + assert.Equal(t, tt.wantUser.Name, u.Name, "name/gecos") + assert.Equal(t, tt.wantUser.HomeDir, u.HomeDir, "home directory") + assert.Equal(t, tt.wantShell, shell, "shell") + }) + } +} + +func TestValidateGetentInput(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"normal username", "alice", true}, + {"numeric UID", "1001", true}, + {"dots and underscores", "alice.bob_test", true}, + {"hyphen", "alice-bob", true}, + {"kerberos principal", "user@REALM", true}, + {"samba machine account", "MACHINE$", true}, + {"NIS compat", "+user", true}, + {"empty", "", false}, + {"null byte", "alice\x00bob", false}, + {"newline", "alice\nbob", false}, + {"tab", "alice\tbob", false}, + {"control char", "alice\x01bob", false}, + {"DEL char", "alice\x7fbob", false}, + {"space rejected", "alice bob", false}, + {"semicolon rejected", "alice;bob", false}, + {"backtick rejected", "alice`bob", false}, + {"pipe rejected", "alice|bob", false}, + {"33 chars exceeds non-linux max", makeLongString(33), runtime.GOOS == "linux"}, + {"256 chars at linux max", makeLongString(256), runtime.GOOS == "linux"}, + {"257 chars exceeds all limits", makeLongString(257), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, validateGetentInput(tt.input)) + }) + } +} + +func makeLongString(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = 'a' + } + return string(b) +} + +func TestRunGetent_RootUser(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + u, shell, err := runGetent("root") + require.NoError(t, err) + assert.Equal(t, "root", u.Username) + assert.Equal(t, "0", u.Uid) + assert.Equal(t, "0", u.Gid) + assert.NotEmpty(t, shell, "root should have a shell") +} + +func TestRunGetent_ByUID(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + u, _, err := runGetent("0") + require.NoError(t, err) + assert.Equal(t, "root", u.Username) + assert.Equal(t, "0", u.Uid) +} + +func TestRunGetent_NonexistentUser(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + _, _, err := runGetent("nonexistent_user_xyzzy_12345") + assert.Error(t, err) +} + +func TestRunGetent_InvalidInput(t *testing.T) { + _, _, err := runGetent("") + assert.Error(t, err) + + _, _, err = runGetent("user\x00name") + assert.Error(t, err) +} + +func TestRunGetent_NotAvailable(t *testing.T) { + if _, err := exec.LookPath("getent"); err == nil { + t.Skip("getent is available, can't test missing case") + } + + _, _, err := runGetent("root") + assert.Error(t, err, "should fail when getent is not installed") +} + +func TestRunIdGroups_CurrentUser(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + groups, err := runIdGroups(current.Username) + require.NoError(t, err) + require.NotEmpty(t, groups, "current user should have at least one group") + + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be a valid uint32", gid) + } +} + +func TestRunIdGroups_NonexistentUser(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + _, err := runIdGroups("nonexistent_user_xyzzy_12345") + assert.Error(t, err) +} + +func TestRunIdGroups_InvalidInput(t *testing.T) { + _, err := runIdGroups("") + assert.Error(t, err) + + _, err = runIdGroups("user\x00name") + assert.Error(t, err) +} + +func TestGetentResultsMatchStdlib(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + getentUser, _, err := runGetent(current.Username) + require.NoError(t, err) + + assert.Equal(t, current.Username, getentUser.Username, "username should match") + assert.Equal(t, current.Uid, getentUser.Uid, "UID should match") + assert.Equal(t, current.Gid, getentUser.Gid, "GID should match") + assert.Equal(t, current.HomeDir, getentUser.HomeDir, "home directory should match") +} + +func TestGetentResultsMatchStdlib_ByUID(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + getentUser, _, err := runGetent(current.Uid) + require.NoError(t, err) + + assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID") + assert.Equal(t, current.Uid, getentUser.Uid, "UID should match") +} + +func TestIdGroupsMatchStdlib(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + stdGroups, err := current.GroupIds() + if err != nil { + t.Skip("os/user.GroupIds() not working, likely CGO_ENABLED=0") + } + + idGroups, err := runIdGroups(current.Username) + require.NoError(t, err) + + // Deduplicate both lists: id -G can return duplicates (e.g., root in Docker) + // and ElementsMatch treats duplicates as distinct. + assert.ElementsMatch(t, uniqueStrings(stdGroups), uniqueStrings(idGroups), "id -G should return same groups as os/user") +} + +func uniqueStrings(ss []string) []string { + seen := make(map[string]struct{}, len(ss)) + out := make([]string, 0, len(ss)) + for _, s := range ss { + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out +} + +// TestGetShellFromPasswd_CurrentUser verifies that getShellFromPasswd correctly +// reads the current user's shell from /etc/passwd by comparing it against what +// getent reports (which goes through NSS). +func TestGetShellFromPasswd_CurrentUser(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + shell := getShellFromPasswd(current.Uid) + if shell == "" { + t.Skip("current user not found in /etc/passwd (may be an NSS-only user)") + } + + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) + + if _, err := exec.LookPath("getent"); err == nil { + _, getentShell, getentErr := runGetent(current.Uid) + if getentErr == nil && getentShell != "" { + assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent") + } + } +} + +// TestGetShellFromPasswd_RootUser verifies that getShellFromPasswd can read +// root's shell from /etc/passwd. Root is guaranteed to be in /etc/passwd on +// any standard Unix system. +func TestGetShellFromPasswd_RootUser(t *testing.T) { + shell := getShellFromPasswd("0") + require.NotEmpty(t, shell, "root (UID 0) must be in /etc/passwd") + assert.True(t, shell[0] == '/', "root shell should be an absolute path, got %q", shell) +} + +// TestGetShellFromPasswd_NonexistentUID verifies that getShellFromPasswd +// returns empty for a UID that doesn't exist in /etc/passwd. +func TestGetShellFromPasswd_NonexistentUID(t *testing.T) { + shell := getShellFromPasswd("4294967294") + assert.Empty(t, shell, "nonexistent UID should return empty shell") +} + +// TestGetShellFromPasswd_MatchesGetentForKnownUsers reads /etc/passwd directly +// and cross-validates every entry against getent to ensure parseGetentPasswd +// and getShellFromPasswd agree on shell values. +func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available") + } + + // Pick a few well-known system UIDs that are virtually always in /etc/passwd. + uids := []string{"0"} // root + + current, err := user.Current() + require.NoError(t, err) + uids = append(uids, current.Uid) + + for _, uid := range uids { + passwdShell := getShellFromPasswd(uid) + if passwdShell == "" { + continue + } + + _, getentShell, err := runGetent(uid) + if err != nil { + continue + } + + assert.Equal(t, getentShell, passwdShell, "shell mismatch for UID %s", uid) + } +} diff --git a/client/ssh/server/getent_windows.go b/client/ssh/server/getent_windows.go new file mode 100644 index 000000000..3e76b3e8e --- /dev/null +++ b/client/ssh/server/getent_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package server + +import "os/user" + +// lookupWithGetent on Windows just delegates to os/user.Lookup. +// Windows does not use NSS/getent; its user lookup works without CGO. +func lookupWithGetent(username string) (*user.User, error) { + return user.Lookup(username) +} + +// currentUserWithGetent on Windows just delegates to os/user.Current. +func currentUserWithGetent() (*user.User, error) { + return user.Current() +} + +// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection. +func getShellFromGetent(_ string) string { + return "" +} + +// groupIdsWithFallback on Windows just delegates to u.GroupIds(). +func groupIdsWithFallback(u *user.User) ([]string, error) { + return u.GroupIds() +} diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 1ddb60f8e..82d3b700f 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -46,8 +46,10 @@ const ( cmdSFTP = "" cmdNonInteractive = "" - // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server - DefaultJWTMaxTokenAge = 5 * 60 + // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server. + // Set to 10 minutes to accommodate identity providers like Azure Entra ID + // that backdate the iat claim by up to 5 minutes. + DefaultJWTMaxTokenAge = 10 * 60 ) var ( @@ -282,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) { // Stop closes the SSH server func (s *Server) Stop() error { s.mu.Lock() - defer s.mu.Unlock() - - if s.sshServer == nil { + sshServer := s.sshServer + if sshServer == nil { + s.mu.Unlock() return nil } + s.sshServer = nil + s.listener = nil + s.mu.Unlock() - if err := s.sshServer.Close(); err != nil { + // Close outside the lock: session handlers need s.mu for unregisterSession. + if err := sshServer.Close(); err != nil { log.Debugf("close SSH server: %v", err) } - s.sshServer = nil - s.listener = nil - + s.mu.Lock() maps.Clear(s.sessions) maps.Clear(s.pendingAuthJWT) maps.Clear(s.connections) @@ -305,6 +309,7 @@ func (s *Server) Stop() error { } } maps.Clear(s.remoteForwardListeners) + s.mu.Unlock() return nil } diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index f12a75961..0e531bb96 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) { } ptyReq, winCh, isPty := session.Pty() - hasCommand := len(session.Command()) > 0 + hasCommand := session.RawCommand() != "" if isPty && !hasCommand { // ssh - PTY interactive session (login) diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go index fea9d2910..1e8ff5e31 100644 --- a/client/ssh/server/shell.go +++ b/client/ssh/server/shell.go @@ -49,10 +49,14 @@ func getWindowsUserShell() string { return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` } -// getUnixUserShell returns the shell for Unix-like systems +// getUnixUserShell returns the shell for Unix-like systems. +// Tries /etc/passwd first (fast, no subprocess), falls back to getent for NSS users. func getUnixUserShell(userID string) string { - shell := getShellFromPasswd(userID) - if shell != "" { + if shell := getShellFromPasswd(userID); shell != "" { + return shell + } + + if shell := getShellFromGetent(userID); shell != "" { return shell } diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go index 799882cbb..bc2aa2d7d 100644 --- a/client/ssh/server/user_utils.go +++ b/client/ssh/server/user_utils.go @@ -23,8 +23,8 @@ func isPlatformUnix() bool { // Dependency injection variables for testing - allows mocking dynamic runtime checks var ( - getCurrentUser = user.Current - lookupUser = user.Lookup + getCurrentUser = currentUserWithGetent + lookupUser = lookupWithGetent getCurrentOS = func() string { return runtime.GOOS } getIsProcessPrivileged = isCurrentProcessPrivileged diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index d80b77042..220e2240f 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -146,32 +146,30 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u } gid := uint32(gid64) - groups, err := s.getSupplementaryGroups(localUser.Username) - if err != nil { - log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + groups, err := s.getSupplementaryGroups(localUser) + if err != nil || len(groups) == 0 { + if err != nil { + log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + } groups = []uint32{gid} } return uid, gid, groups, nil } -// getSupplementaryGroups retrieves supplementary group IDs for a user -func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { - u, err := user.Lookup(username) +// getSupplementaryGroups retrieves supplementary group IDs for a user. +// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds. +func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) { + groupIDStrings, err := groupIdsWithFallback(u) if err != nil { - return nil, fmt.Errorf("lookup user %s: %w", username, err) - } - - groupIDStrings, err := u.GroupIds() - if err != nil { - return nil, fmt.Errorf("get group IDs for user %s: %w", username, err) + return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err) } groups := make([]uint32, len(groupIDStrings)) for i, gidStr := range groupIDStrings { gid64, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err) + return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, u.Username, err) } groups[i] = uint32(gid64) } diff --git a/client/status/status.go b/client/status/status.go index f13163a41..8c932bbab 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -25,6 +25,38 @@ import ( "github.com/netbirdio/netbird/version" ) +// DaemonStatus represents the current state of the NetBird daemon. +// These values mirror internal.StatusType but are defined here to avoid an import cycle. +type DaemonStatus string + +const ( + DaemonStatusIdle DaemonStatus = "Idle" + DaemonStatusConnecting DaemonStatus = "Connecting" + DaemonStatusConnected DaemonStatus = "Connected" + DaemonStatusNeedsLogin DaemonStatus = "NeedsLogin" + DaemonStatusLoginFailed DaemonStatus = "LoginFailed" + DaemonStatusSessionExpired DaemonStatus = "SessionExpired" +) + +// ParseDaemonStatus converts a raw status string to DaemonStatus. +// Unrecognized values are preserved as-is to remain visible during version skew. +func ParseDaemonStatus(s string) DaemonStatus { + return DaemonStatus(s) +} + +// ConvertOptions holds parameters for ConvertToStatusOutputOverview. +type ConvertOptions struct { + Anonymize bool + DaemonVersion string + DaemonStatus DaemonStatus + StatusFilter string + PrefixNamesFilter []string + PrefixNamesFilterMap map[string]struct{} + IPsFilter map[string]struct{} + ConnectionTypeFilter string + ProfileName string +} + type PeerStateDetailOutput struct { FQDN string `json:"fqdn" yaml:"fqdn"` IP string `json:"netbirdIp" yaml:"netbirdIp"` @@ -102,6 +134,7 @@ type OutputOverview struct { Peers PeersStateOutput `json:"peers" yaml:"peers"` CliVersion string `json:"cliVersion" yaml:"cliVersion"` DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"` + DaemonStatus DaemonStatus `json:"daemonStatus" yaml:"daemonStatus"` ManagementState ManagementStateOutput `json:"management" yaml:"management"` SignalState SignalStateOutput `json:"signal" yaml:"signal"` Relays RelayStateOutput `json:"relays" yaml:"relays"` @@ -120,7 +153,8 @@ type OutputOverview struct { SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"` } -func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { +// ConvertToStatusOutputOverview converts protobuf status to the output overview. +func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertOptions) OutputOverview { managementState := pbFullStatus.GetManagementState() managementOverview := ManagementStateOutput{ URL: managementState.GetURL(), @@ -137,12 +171,13 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da relayOverview := mapRelays(pbFullStatus.GetRelays()) sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState()) - peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) + peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter) overview := OutputOverview{ Peers: peersOverview, CliVersion: version.NetbirdVersion(), - DaemonVersion: daemonVersion, + DaemonVersion: opts.DaemonVersion, + DaemonStatus: opts.DaemonStatus, ManagementState: managementOverview, SignalState: signalOverview, Relays: relayOverview, @@ -157,11 +192,11 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), Events: mapEvents(pbFullStatus.GetEvents()), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), - ProfileName: profName, + ProfileName: opts.ProfileName, SSHServerState: sshServerOverview, } - if anon { + if opts.Anonymize { anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) anonymizeOverview(anonymizer, &overview) } diff --git a/client/status/status_test.go b/client/status/status_test.go index b02d78d64..7754eebae 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -176,6 +176,7 @@ var overview = OutputOverview{ Events: []SystemEventOutput{}, CliVersion: version.NetbirdVersion(), DaemonVersion: "0.14.1", + DaemonStatus: DaemonStatusConnected, ManagementState: ManagementStateOutput{ URL: "my-awesome-management.com:443", Connected: true, @@ -238,7 +239,10 @@ var overview = OutputOverview{ } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { - convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "") + convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), ConvertOptions{ + DaemonVersion: resp.GetDaemonVersion(), + DaemonStatus: ParseDaemonStatus(resp.GetStatus()), + }) assert.Equal(t, overview, convertedResult) } @@ -329,6 +333,7 @@ func TestParsingToJSON(t *testing.T) { }, "cliVersion": "development", "daemonVersion": "0.14.1", + "daemonStatus": "Connected", "management": { "url": "my-awesome-management.com:443", "connected": true, @@ -452,6 +457,7 @@ func TestParsingToYAML(t *testing.T) { networks: [] cliVersion: development daemonVersion: 0.14.1 +daemonStatus: Connected management: url: my-awesome-management.com:443 connected: true diff --git a/client/system/info.go b/client/system/info.go index 01176e765..175d1f07f 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -2,7 +2,6 @@ package system import ( "context" - "net" "net/netip" "strings" @@ -145,56 +144,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string { return v } -func networkAddresses() ([]NetworkAddress, error) { - interfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - var netAddresses []NetworkAddress - for _, iface := range interfaces { - if iface.HardwareAddr.String() == "" { - continue - } - addrs, err := iface.Addrs() - if err != nil { - continue - } - - for _, address := range addrs { - ipNet, ok := address.(*net.IPNet) - if !ok { - continue - } - - if ipNet.IP.IsLoopback() { - continue - } - - netAddr := NetworkAddress{ - NetIP: netip.MustParsePrefix(ipNet.String()), - Mac: iface.HardwareAddr.String(), - } - - if isDuplicated(netAddresses, netAddr) { - continue - } - - netAddresses = append(netAddresses, netAddr) - } - } - return netAddresses, nil -} - -func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { - for _, duplicated := range addresses { - if duplicated.NetIP == addr.NetIP { - return true - } - } - return false -} - // GetInfoWithChecks retrieves and parses the system information with applied checks. func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { log.Debugf("gathering system information with checks: %d", len(checks)) diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go index 8e1353151..755172842 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info { systemHostname, _ := os.Hostname() + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + return &Info{ - GoOS: runtime.GOOS, - Kernel: osInfo[0], - Platform: runtime.GOARCH, - OS: osName, - OSVersion: osVersion, - Hostname: extractDeviceName(ctx, systemHostname), - CPUs: runtime.NumCPU(), - NetbirdVersion: version.NetbirdVersion(), - UIVersion: extractUserAgent(ctx), - KernelVersion: osInfo[1], - Environment: env, + GoOS: runtime.GOOS, + Kernel: osInfo[0], + Platform: runtime.GOARCH, + OS: osName, + OSVersion: osVersion, + Hostname: extractDeviceName(ctx, systemHostname), + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + UIVersion: extractUserAgent(ctx), + KernelVersion: osInfo[1], + NetworkAddresses: addrs, + Environment: env, } } diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 322609db4..ad42b1edf 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -2,12 +2,16 @@ package system import ( "context" + "net" + "net/netip" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update func UpdateStaticInfoAsync() { // do nothing } @@ -15,11 +19,24 @@ func UpdateStaticInfoAsync() { // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - // Convert fixed-size byte arrays to Go strings sysName := extractOsName(ctx, "sysName") swVersion := extractOsVersion(ctx, "swVersion") - gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion} + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + + gio := &Info{ + Kernel: sysName, + OSVersion: swVersion, + Platform: "unknown", + OS: sysName, + GoOS: runtime.GOOS, + CPUs: runtime.NumCPU(), + KernelVersion: swVersion, + NetworkAddresses: addrs, + } gio.Hostname = extractDeviceName(ctx, "hostname") gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) @@ -27,6 +44,66 @@ func GetInfo(ctx context.Context) *Info { return gio } +// networkAddresses returns the list of network addresses on iOS. +// On iOS, hardware (MAC) addresses are not available due to Apple's privacy +// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we +// leave Mac empty to match Android's behavior. We also skip the HardwareAddr +// check that other platforms use and filter out link-local addresses as they +// are not useful for posture checks. +func networkAddresses() ([]NetworkAddress, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var netAddresses []NetworkAddress + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, address := range addrs { + netAddr, ok := toNetworkAddress(address) + if !ok { + continue + } + if isDuplicated(netAddresses, netAddr) { + continue + } + netAddresses = append(netAddresses, netAddr) + } + } + return netAddresses, nil +} + +func toNetworkAddress(address net.Addr) (NetworkAddress, bool) { + ipNet, ok := address.(*net.IPNet) + if !ok { + return NetworkAddress{}, false + } + if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() { + return NetworkAddress{}, false + } + prefix, err := netip.ParsePrefix(ipNet.String()) + if err != nil { + return NetworkAddress{}, false + } + return NetworkAddress{NetIP: prefix, Mac: ""}, true +} + +func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { + for _, duplicated := range addresses { + if duplicated.NetIP == addr.NetIP { + return true + } + } + return false +} + // checkFileAndProcess checks if the file path exists and if a process is running at that path. func checkFileAndProcess(paths []string) ([]File, error) { return []File{}, nil diff --git a/client/system/network_addr.go b/client/system/network_addr.go new file mode 100644 index 000000000..5423cf8ad --- /dev/null +++ b/client/system/network_addr.go @@ -0,0 +1,66 @@ +//go:build !ios + +package system + +import ( + "net" + "net/netip" +) + +func networkAddresses() ([]NetworkAddress, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var netAddresses []NetworkAddress + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + if iface.HardwareAddr.String() == "" { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + mac := iface.HardwareAddr.String() + for _, address := range addrs { + netAddr, ok := toNetworkAddress(address, mac) + if !ok { + continue + } + if isDuplicated(netAddresses, netAddr) { + continue + } + netAddresses = append(netAddresses, netAddr) + } + } + return netAddresses, nil +} + +func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) { + ipNet, ok := address.(*net.IPNet) + if !ok { + return NetworkAddress{}, false + } + if ipNet.IP.IsLoopback() { + return NetworkAddress{}, false + } + prefix, err := netip.ParsePrefix(ipNet.String()) + if err != nil { + return NetworkAddress{}, false + } + return NetworkAddress{NetIP: prefix, Mac: mac}, true +} + +func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { + for _, duplicated := range addresses { + if duplicated.NetIP == addr.NetIP { + return true + } + } + return false +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 0290e17d5..c149b2152 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -34,7 +34,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - protobuf "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" @@ -308,12 +307,14 @@ type serviceClient struct { sshJWTCacheTTL int connected bool - update *version.Update daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + isEnforcedUpdate bool + lastNotifiedVersion string settingsEnabled bool profilesEnabled bool + networksEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window @@ -323,7 +324,8 @@ type serviceClient struct { exitNodeMu sync.Mutex mExitNodeItems []menuHandler - exitNodeStates []exitNodeState + exitNodeRetryCancel context.CancelFunc + mExitNodeSeparator *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window @@ -367,7 +369,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdateAndStart("nb/client-ui"), + networksEnabled: true, } s.eventHandler = newEventHandler(s) @@ -828,7 +830,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log return nil } -func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -850,9 +852,7 @@ func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) e return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{ - AutoUpdate: protobuf.Bool(wannaAutoUpdate), - }); err != nil { + if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { return fmt.Errorf("start connection: %w", err) } @@ -922,9 +922,11 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetIcon(s.icConnectedDot) s.mUp.Disable() s.mDown.Enable() - s.mNetworks.Enable() - s.mExitNode.Enable() - go s.updateExitNodes() + if s.networksEnabled { + s.mNetworks.Enable() + s.mExitNode.Enable() + } + s.startExitNodeRefresh() systrayIconState = true case status.Status == string(internal.StatusConnecting): s.setConnectingStatus() @@ -933,13 +935,13 @@ func (s *serviceClient) updateStatus() error { systrayIconState = false } - // the updater struct notify by the upgrades available only, but if meanwhile the daemon has successfully - // updated must reset the mUpdate visibility state + // if the daemon version changed (e.g. after a successful update), reset the update indication if s.daemonVersion != status.DaemonVersion { - s.mUpdate.Hide() + if s.daemonVersion != "" { + s.mUpdate.Hide() + s.isUpdateIconActive = false + } s.daemonVersion = status.DaemonVersion - - s.isUpdateIconActive = s.update.SetDaemonVersion(status.DaemonVersion) if !s.isUpdateIconActive { if systrayIconState { systray.SetTemplateIcon(iconConnectedMacOS, s.icConnected) @@ -985,6 +987,7 @@ func (s *serviceClient) setDisconnectedStatus() { s.mUp.Enable() s.mNetworks.Disable() s.mExitNode.Disable() + s.cancelExitNodeRetry() go s.updateExitNodes() } @@ -1090,19 +1093,18 @@ func (s *serviceClient) onTrayReady() { // update exit node menu in case service is already connected go s.updateExitNodes() - s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon for { + // Check features before status so menus respect disable flags before being enabled + s.checkAndUpdateFeatures() + err := s.updateStatus() if err != nil { log.Errorf("error while updating status: %v", err) } - // Check features periodically to handle daemon restarts - s.checkAndUpdateFeatures() - time.Sleep(2 * time.Second) } }() @@ -1134,6 +1136,13 @@ func (s *serviceClient) onTrayReady() { } } }) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + if newVersion, ok := event.Metadata["new_version_available"]; ok { + _, enforced := event.Metadata["enforced"] + log.Infof("received new_version_available event: version=%s enforced=%v", newVersion, enforced) + s.onUpdateAvailable(newVersion, enforced) + } + }) go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) @@ -1294,6 +1303,16 @@ func (s *serviceClient) checkAndUpdateFeatures() { s.mProfile.setEnabled(profilesEnabled) } } + + // Update networks and exit node menus based on current features + s.networksEnabled = features == nil || !features.DisableNetworks + if s.networksEnabled && s.connected { + s.mNetworks.Enable() + s.mExitNode.Enable() + } else { + s.mNetworks.Disable() + s.mExitNode.Disable() + } } // getFeatures from the daemon to determine which features are enabled/disabled. @@ -1506,10 +1525,18 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { return &config } -func (s *serviceClient) onUpdateAvailable() { +func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) { s.updateIndicationLock.Lock() defer s.updateIndicationLock.Unlock() + s.isEnforcedUpdate = enforced + if enforced { + s.mUpdate.SetTitle("Install version " + newVersion) + } else { + s.lastNotifiedVersion = "" + s.mUpdate.SetTitle("Download latest version") + } + s.mUpdate.Show() s.isUpdateIconActive = true @@ -1518,6 +1545,11 @@ func (s *serviceClient) onUpdateAvailable() { } else { systray.SetTemplateIcon(iconUpdateDisconnectedMacOS, s.icUpdateDisconnected) } + + if enforced && s.lastNotifiedVersion != newVersion { + s.lastNotifiedVersion = newVersion + s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install")) + } } // onSessionExpire sends a notification to the user when the session expires. diff --git a/client/ui/debug.go b/client/ui/debug.go index 29f73a66a..4ebe4d675 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -24,9 +24,10 @@ import ( // Initial state for the debug collection type debugInitialState struct { - wasDown bool - logLevel proto.LogLevel - isLevelTrace bool + wasDown bool + needsRestoreUp bool + logLevel proto.LogLevel + isLevelTrace bool } // Debug collection parameters @@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug( conn proto.DaemonServiceClient, state *debugInitialState, enablePersistence bool, -) error { +) { if state.wasDown { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - return fmt.Errorf("bring service up: %v", err) + log.Warnf("failed to bring service up: %v", err) + } else { + log.Info("Service brought up for debug") + time.Sleep(time.Second * 10) } - log.Info("Service brought up for debug") - time.Sleep(time.Second * 10) } if !state.isLevelTrace { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil { - return fmt.Errorf("set log level to TRACE: %v", err) + log.Warnf("failed to set log level to TRACE: %v", err) + } else { + log.Info("Log level set to TRACE for debug") } - log.Info("Log level set to TRACE for debug") } if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - return fmt.Errorf("bring service down: %v", err) + log.Warnf("failed to bring service down: %v", err) + } else { + state.needsRestoreUp = !state.wasDown + time.Sleep(time.Second) } - time.Sleep(time.Second) if enablePersistence { if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { - return fmt.Errorf("enable sync response persistence: %v", err) + log.Warnf("failed to enable sync response persistence: %v", err) + } else { + log.Info("Sync response persistence enabled for debug") } - log.Info("Sync response persistence enabled for debug") } if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - return fmt.Errorf("bring service back up: %v", err) + log.Warnf("failed to bring service back up: %v", err) + } else { + state.needsRestoreUp = false + time.Sleep(time.Second * 3) } - time.Sleep(time.Second * 3) if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { log.Warnf("failed to start CPU profiling: %v", err) } - - return nil } func (s *serviceClient) collectDebugData( @@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData( var wg sync.WaitGroup startProgressTracker(ctx, &wg, params.duration, progress) - if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { - return err - } + s.configureServiceForDebug(conn, state, params.enablePersistence) wg.Wait() progress.progressBar.Hide() @@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection( // Restore service to original state func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) { + if state.needsRestoreUp { + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + log.Warnf("failed to restore up state: %v", err) + } else { + log.Info("Service state restored to up") + } + } + if state.wasDown { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("Failed to restore down state: %v", err) + log.Warnf("failed to restore down state: %v", err) } else { log.Info("Service state restored to down") } @@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat if !state.isLevelTrace { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil { - log.Errorf("Failed to restore log level: %v", err) + log.Warnf("failed to restore log level: %v", err) } else { log.Info("Log level restored to original setting") } diff --git a/client/ui/event/event.go b/client/ui/event/event.go index 4d949416d..b8ed09a5c 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -107,12 +107,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) { handlers := slices.Clone(e.handlers) e.mu.Unlock() - // critical events are always shown - if !enabled && event.Severity != proto.SystemEvent_CRITICAL { - return - } - - if event.UserMessage != "" { + if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) { title := e.getEventTitle(event) body := event.UserMessage id := event.Metadata["id"] diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 2216c8aeb..60a580dae 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -82,7 +82,7 @@ func (h *eventHandler) handleConnectClick() { go func() { defer connectCancel() - if err := h.client.menuUpClick(connectCtx, true); err != nil { + if err := h.client.menuUpClick(connectCtx); err != nil { st, ok := status.FromError(err) if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") @@ -100,8 +100,7 @@ func (h *eventHandler) handleConnectClick() { func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() - - h.client.exitNodeStates = []exitNodeState{} + h.client.cancelExitNodeRetry() if h.client.connectCancel != nil { log.Debugf("cancelling ongoing connect operation") @@ -212,9 +211,42 @@ func (h *eventHandler) handleGitHubClick() { } func (h *eventHandler) handleUpdateClick() { - if err := openURL(version.DownloadUrl()); err != nil { - log.Errorf("failed to open download URL: %v", err) + h.client.updateIndicationLock.Lock() + enforced := h.client.isEnforcedUpdate + h.client.updateIndicationLock.Unlock() + + if !enforced { + if err := openURL(version.DownloadUrl()); err != nil { + log.Errorf("failed to open download URL: %v", err) + } + return } + + // prevent blocking against a busy server + h.client.mUpdate.Disable() + go func() { + defer h.client.mUpdate.Enable() + conn, err := h.client.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get service client for update: %v", err) + _ = openURL(version.DownloadUrl()) + return + } + + resp, err := conn.TriggerUpdate(h.client.ctx, &proto.TriggerUpdateRequest{}) + if err != nil { + log.Errorf("TriggerUpdate failed: %v", err) + _ = openURL(version.DownloadUrl()) + return + } + if !resp.Success { + log.Errorf("TriggerUpdate failed: %s", resp.ErrorMsg) + _ = openURL(version.DownloadUrl()) + return + } + + log.Infof("update triggered via daemon") + }() } func (h *eventHandler) handleNetworksClick() { diff --git a/client/ui/network.go b/client/ui/network.go index 9a5ad7662..571e871bb 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "runtime" - "slices" "sort" "strings" "time" @@ -34,11 +33,6 @@ const ( type filter string -type exitNodeState struct { - id string - selected bool -} - func (s *serviceClient) showNetworksUI() { s.wNetworks = s.app.NewWindow("Networks") s.wNetworks.SetOnClosed(s.cancel) @@ -335,16 +329,75 @@ func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, s.updateNetworks(grid, f) } -func (s *serviceClient) updateExitNodes() { +// startExitNodeRefresh initiates exit node menu refresh after connecting. +// On Windows, TrayOpenedCh is not supported by the systray library, so we use +// a background poller to keep exit nodes in sync while connected. +// On macOS/Linux, TrayOpenedCh handles refreshes on each tray open. +func (s *serviceClient) startExitNodeRefresh() { + s.cancelExitNodeRetry() + + if runtime.GOOS == "windows" { + ctx, cancel := context.WithCancel(s.ctx) + s.exitNodeMu.Lock() + s.exitNodeRetryCancel = cancel + s.exitNodeMu.Unlock() + + go s.pollExitNodes(ctx) + } else { + go s.updateExitNodes() + } +} + +func (s *serviceClient) cancelExitNodeRetry() { + s.exitNodeMu.Lock() + if s.exitNodeRetryCancel != nil { + s.exitNodeRetryCancel() + s.exitNodeRetryCancel = nil + } + s.exitNodeMu.Unlock() +} + +// pollExitNodes periodically refreshes exit nodes while connected. +// Uses a short initial interval to catch routes from the management sync, +// then switches to a longer interval for ongoing updates. +func (s *serviceClient) pollExitNodes(ctx context.Context) { + // Initial fast polling to catch routes as they appear after connect. + for i := 0; i < 5; i++ { + if s.updateExitNodes() { + break + } + select { + case <-ctx.Done(): + return + case <-time.After(2 * time.Second): + } + } + + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.updateExitNodes() + } + } +} + +// updateExitNodes fetches exit nodes from the daemon and recreates the menu. +// Returns true if exit nodes were found. +func (s *serviceClient) updateExitNodes() bool { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf("get client: %v", err) - return + return false } exitNodes, err := s.getExitNodes(conn) if err != nil { log.Errorf("get exit nodes: %v", err) - return + return false } s.exitNodeMu.Lock() @@ -354,34 +407,24 @@ func (s *serviceClient) updateExitNodes() { if len(s.mExitNodeItems) > 0 { s.mExitNode.Enable() - } else { - s.mExitNode.Disable() + return true } + + s.mExitNode.Disable() + return false } func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { - var exitNodeIDs []exitNodeState - for _, node := range exitNodes { - exitNodeIDs = append(exitNodeIDs, exitNodeState{ - id: node.ID, - selected: node.Selected, - }) - } - - sort.Slice(exitNodeIDs, func(i, j int) bool { - return exitNodeIDs[i].id < exitNodeIDs[j].id - }) - if slices.Equal(s.exitNodeStates, exitNodeIDs) { - log.Debug("Exit node menu already up to date") - return - } - for _, node := range s.mExitNodeItems { node.cancel() node.Hide() node.Remove() } s.mExitNodeItems = nil + if s.mExitNodeSeparator != nil { + s.mExitNodeSeparator.Remove() + s.mExitNodeSeparator = nil + } if s.mExitNodeDeselectAll != nil { s.mExitNodeDeselectAll.Remove() s.mExitNodeDeselectAll = nil @@ -413,34 +456,38 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { go s.handleChecked(ctx, node.ID, menuItem) } - s.exitNodeStates = exitNodeIDs - if showDeselectAll { - s.mExitNode.AddSeparator() - deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") - s.mExitNodeDeselectAll = deselectAllItem - go func() { - for { - _, ok := <-deselectAllItem.ClickedCh - if !ok { - // channel closed: exit the goroutine - return - } - exitNodes, err := s.handleExitNodeMenuDeselectAll() - if err != nil { - log.Warnf("failed to handle deselect all exit nodes: %v", err) - } else { - s.exitNodeMu.Lock() - s.recreateExitNodeMenu(exitNodes) - s.exitNodeMu.Unlock() - } - } - - }() + s.addExitNodeDeselectAll() } } +func (s *serviceClient) addExitNodeDeselectAll() { + sep := s.mExitNode.AddSubMenuItem("───────────────", "") + sep.Disable() + s.mExitNodeSeparator = sep + + deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") + s.mExitNodeDeselectAll = deselectAllItem + + go func() { + for { + _, ok := <-deselectAllItem.ClickedCh + if !ok { + return + } + exitNodes, err := s.handleExitNodeMenuDeselectAll() + if err != nil { + log.Warnf("failed to handle deselect all exit nodes: %v", err) + } else { + s.exitNodeMu.Lock() + s.recreateExitNodeMenu(exitNodes) + s.exitNodeMu.Unlock() + } + } + }() +} + func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) { ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout) defer cancel() diff --git a/client/ui/profile.go b/client/ui/profile.go index a38d8918a..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -397,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func(context.Context, bool) error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -411,7 +411,7 @@ type newProfileMenuArgs struct { profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func(context.Context, bool) error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -579,7 +579,7 @@ func (p *profileMenu) refresh() { connectCtx, connectCancel := context.WithCancel(p.ctx) p.serviceClient.connectCancel = connectCancel - if err := p.upClickCallback(connectCtx, false); err != nil { + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go index 76440d684..bf47ac434 100644 --- a/client/ui/quickactions.go +++ b/client/ui/quickactions.go @@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() { connCmd := connectCommand{ connectClient: func() error { - return s.menuUpClick(s.ctx, false) + return s.menuUpClick(s.ctx) }, } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 26022ffc7..d8e50ab6d 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -18,7 +18,6 @@ import ( "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/version" ) const ( @@ -350,7 +349,7 @@ func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) pbFullStatus := fullStatus.ToProto() - return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, false, version.NetbirdVersion(), "", nil, nil, nil, "", ""), nil + return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, nbstatus.ConvertOptions{}), nil } // createStatusMethod creates the status method that returns JSON diff --git a/combined/cmd/config.go b/combined/cmd/config.go index 04155f72e..ce4df8394 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -7,6 +7,7 @@ import ( "net/netip" "os" "path" + "path/filepath" "strings" "time" @@ -70,6 +71,8 @@ type ServerConfig struct { DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"` Auth AuthConfig `yaml:"auth"` Store StoreConfig `yaml:"store"` + ActivityStore StoreConfig `yaml:"activityStore"` + AuthStore StoreConfig `yaml:"authStore"` ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"` } @@ -170,14 +173,17 @@ type RelaysConfig struct { type StoreConfig struct { Engine string `yaml:"engine"` EncryptionKey string `yaml:"encryptionKey"` - DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines + DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines + File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir) } // ReverseProxyConfig contains reverse proxy settings type ReverseProxyConfig struct { - TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"` - TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"` - TrustedPeers []string `yaml:"trustedPeers"` + TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"` + TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"` + TrustedPeers []string `yaml:"trustedPeers"` + AccessLogRetentionDays int `yaml:"accessLogRetentionDays"` + AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"` } // DefaultConfig returns a CombinedConfig with default values @@ -532,6 +538,74 @@ func stripSignalProtocol(uri string) string { return uri } +func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) { + var ttl time.Duration + if relays.CredentialsTTL != "" { + var err error + ttl, err = time.ParseDuration(relays.CredentialsTTL) + if err != nil { + return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err) + } + } + return &nbconfig.Relay{ + Addresses: relays.Addresses, + CredentialsTTL: util.Duration{Duration: ttl}, + Secret: relays.Secret, + }, nil +} + +// buildEmbeddedIdPConfig builds the embedded IdP configuration. +// authStore overrides auth.storage when set. +func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) { + authStorageType := mgmt.Auth.Storage.Type + authStorageDSN := c.Server.AuthStore.DSN + if c.Server.AuthStore.Engine != "" { + authStorageType = c.Server.AuthStore.Engine + } + if authStorageType == "" { + authStorageType = "sqlite3" + } + authStorageFile := "" + if authStorageType == "postgres" { + if authStorageDSN == "" { + return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres") + } + } else { + authStorageFile = path.Join(mgmt.DataDir, "idp.db") + if c.Server.AuthStore.File != "" { + authStorageFile = c.Server.AuthStore.File + if !filepath.IsAbs(authStorageFile) { + authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile) + } + } + } + + cfg := &idp.EmbeddedIdPConfig{ + Enabled: true, + Issuer: mgmt.Auth.Issuer, + LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled, + SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled, + Storage: idp.EmbeddedStorageConfig{ + Type: authStorageType, + Config: idp.EmbeddedStorageTypeConfig{ + File: authStorageFile, + DSN: authStorageDSN, + }, + }, + DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs, + CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs, + } + + if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" { + cfg.Owner = &idp.OwnerConfig{ + Email: mgmt.Auth.Owner.Email, + Hash: mgmt.Auth.Owner.Password, + } + } + + return cfg, nil +} + // ToManagementConfig converts CombinedConfig to management server config func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { mgmt := c.Management @@ -550,19 +624,11 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { // Build relay config var relayConfig *nbconfig.Relay if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" { - var ttl time.Duration - if mgmt.Relays.CredentialsTTL != "" { - var err error - ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL) - if err != nil { - return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err) - } - } - relayConfig = &nbconfig.Relay{ - Addresses: mgmt.Relays.Addresses, - CredentialsTTL: util.Duration{Duration: ttl}, - Secret: mgmt.Relays.Secret, + relay, err := buildRelayConfig(mgmt.Relays) + if err != nil { + return nil, err } + relayConfig = relay } // Build signal config @@ -581,7 +647,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { // Build reverse proxy config reverseProxy := nbconfig.ReverseProxy{ - TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount, + TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount, + AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays, + AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours, } for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies { if prefix, err := netip.ParsePrefix(p); err == nil { @@ -598,31 +666,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { httpConfig := &nbconfig.HttpServerConfig{} // Build embedded IDP config (always enabled in combined server) - storageFile := mgmt.Auth.Storage.File - if storageFile == "" { - storageFile = path.Join(mgmt.DataDir, "idp.db") - } - - embeddedIdP := &idp.EmbeddedIdPConfig{ - Enabled: true, - Issuer: mgmt.Auth.Issuer, - LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled, - SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled, - Storage: idp.EmbeddedStorageConfig{ - Type: mgmt.Auth.Storage.Type, - Config: idp.EmbeddedStorageTypeConfig{ - File: storageFile, - }, - }, - DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs, - CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs, - } - - if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" { - embeddedIdP.Owner = &idp.OwnerConfig{ - Email: mgmt.Auth.Owner.Email, - Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text - } + embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt) + if err != nil { + return nil, err } // Set HTTP config fields for embedded IDP diff --git a/combined/cmd/root.go b/combined/cmd/root.go index b8ea7064c..db986b4d4 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/relay/healthcheck" relayServer "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/ws" sharedMetrics "github.com/netbirdio/netbird/shared/metrics" "github.com/netbirdio/netbird/shared/relay/auth" @@ -140,6 +141,23 @@ func initializeConfig() error { os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn) } } + if file := config.Server.Store.File; file != "" { + os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file) + } + + if engine := config.Server.ActivityStore.Engine; engine != "" { + engineLower := strings.ToLower(engine) + if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" { + return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres") + } + os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower) + if dsn := config.Server.ActivityStore.DSN; dsn != "" { + os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn) + } + } + if file := config.Server.ActivityStore.File; file != "" { + os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file) + } log.Infof("Starting combined NetBird server") logConfig(config) @@ -476,9 +494,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) { func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { mgmt := cfg.Management - dnsDomain := mgmt.DnsDomain - singleAccModeDomain := dnsDomain - // Extract port from listen address _, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress) if err != nil { @@ -490,8 +505,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* mgmtSrv := mgmtServer.NewServer( &mgmtServer.Config{ NbConfig: mgmtConfig, - DNSDomain: dnsDomain, - MgmtSingleAccModeDomain: singleAccModeDomain, + DNSDomain: "", + MgmtSingleAccModeDomain: "", + AutoResolveDomains: true, MgmtPort: mgmtPort, MgmtMetricsPort: cfg.Server.MetricsPort, DisableMetrics: mgmt.DisableAnonymousMetrics, @@ -508,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) - var relayAcceptFn func(conn net.Conn) + var relayAcceptFn func(conn listener.Conn) if relaySrv != nil { relayAcceptFn = relaySrv.RelayAccept() } @@ -548,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re } // handleRelayWebSocket handles incoming WebSocket connections for the relay service -func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) { +func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) { acceptOptions := &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, } @@ -570,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func( return } - lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress) - if err != nil { - _ = wsConn.Close(websocket.StatusInternalError, "internal error") - return - } - log.Debugf("Relay WS client connected from: %s", rAddr) - conn := ws.NewConn(wsConn, lAddr, rAddr) + conn := ws.NewConn(wsConn, rAddr) acceptFn(conn) } @@ -668,8 +678,11 @@ func logEnvVars() { if strings.HasPrefix(env, "NB_") { key, _, _ := strings.Cut(env, "=") value := os.Getenv(key) - if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") { + keyLower := strings.ToLower(key) + if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") { value = maskSecret(value) + } else if strings.Contains(keyLower, "dsn") { + value = maskDSNPassword(value) } log.Infof(" %s=%s", key, value) found = true diff --git a/combined/cmd/token.go b/combined/cmd/token.go index 9393c6c46..550480062 100644 --- a/combined/cmd/token.go +++ b/combined/cmd/token.go @@ -42,6 +42,9 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn) } } + if file := cfg.Server.Store.File; file != "" { + os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file) + } datadir := cfg.Management.DataDir engine := types.Engine(cfg.Management.Store.Engine) diff --git a/combined/config.yaml.example b/combined/config.yaml.example index b3b38c5a9..dce658d89 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -103,6 +103,19 @@ server: engine: "sqlite" # sqlite, postgres, or mysql dsn: "" # Connection string for postgres or mysql encryptionKey: "" + # file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db) + + # Activity events store configuration (optional, defaults to sqlite in dataDir) + # activityStore: + # engine: "sqlite" # sqlite or postgres + # dsn: "" # Connection string for postgres + # file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db) + + # Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db) + # authStore: + # engine: "sqlite3" # sqlite3 or postgres + # dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable") + # file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db) # Reverse proxy settings (optional) # reverseProxy: diff --git a/flow/client/client.go b/flow/client/client.go index 318fcfe1e..8ad637974 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -14,7 +14,6 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -26,11 +25,22 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +var ErrClientClosed = errors.New("client is closed") + +// minHealthyDuration is the minimum time a stream must survive before a failure +// resets the backoff timer. Streams that fail faster are considered unhealthy and +// should not reset backoff, so that MaxElapsedTime can eventually stop retries. +const minHealthyDuration = 5 * time.Second + type GRPCClient struct { realClient proto.FlowServiceClient clientConn *grpc.ClientConn stream proto.FlowService_EventsClient - streamMu sync.Mutex + target string + opts []grpc.DialOption + closed bool // prevent creating conn in the middle of the Close + receiving bool // prevent concurrent Receive calls + mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving } func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) { @@ -65,7 +75,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`), ) - conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...) + target := parsedURL.Host + conn, err := grpc.NewClient(target, opts...) if err != nil { return nil, fmt.Errorf("creating new grpc client: %w", err) } @@ -73,30 +84,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl return &GRPCClient{ realClient: proto.NewFlowServiceClient(conn), clientConn: conn, + target: target, + opts: opts, }, nil } func (c *GRPCClient) Close() error { - c.streamMu.Lock() - defer c.streamMu.Unlock() - + c.mu.Lock() + c.closed = true c.stream = nil - if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) { + conn := c.clientConn + c.clientConn = nil + c.mu.Unlock() + + if conn == nil { + return nil + } + + if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) { return fmt.Errorf("close client connection: %w", err) } return nil } +func (c *GRPCClient) Send(event *proto.FlowEvent) error { + c.mu.Lock() + stream := c.stream + c.mu.Unlock() + + if stream == nil { + return errors.New("stream not initialized") + } + + if err := stream.Send(event); err != nil { + return fmt.Errorf("send flow event: %w", err) + } + + return nil +} + func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error { + c.mu.Lock() + if c.receiving { + c.mu.Unlock() + return errors.New("concurrent Receive calls are not supported") + } + c.receiving = true + c.mu.Unlock() + defer func() { + c.mu.Lock() + c.receiving = false + c.mu.Unlock() + }() + backOff := defaultBackoff(ctx, interval) operation := func() error { - if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { - return fmt.Errorf("receive: %w: %w", err, context.Canceled) - } + stream, err := c.establishStream(ctx) + if err != nil { + log.Errorf("failed to establish flow stream, retrying: %v", err) + return c.handleRetryableError(err, time.Time{}, backOff) + } + + streamStart := time.Now() + + if err := c.receive(stream, msgHandler); err != nil { log.Errorf("receive failed: %v", err) - return fmt.Errorf("receive: %w", err) + return c.handleRetryableError(err, streamStart, backOff) } return nil } @@ -108,37 +162,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan return nil } -func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error { - if c.clientConn.GetState() == connectivity.Shutdown { - return errors.New("connection to flow receiver has been shut down") +// handleRetryableError resets the backoff timer if the stream was healthy long +// enough and recreates the underlying ClientConn so that gRPC's internal +// subchannel backoff does not accumulate and compete with our own retry timer. +// A zero streamStart means the stream was never established. +func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error { + if isContextDone(err) { + return backoff.Permanent(err) } - stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true)) - if err != nil { - return fmt.Errorf("create event stream: %w", err) + var permErr *backoff.PermanentError + if errors.As(err, &permErr) { + return err } - err = stream.Send(&proto.FlowEvent{IsInitiator: true}) + // Reset the backoff so the next retry starts with a short delay instead of + // continuing the already-elapsed timer. Only do this if the stream was healthy + // long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime. + if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration { + backOff.Reset() + } + + if recreateErr := c.recreateConnection(); recreateErr != nil { + log.Errorf("recreate connection: %v", recreateErr) + return recreateErr + } + + log.Infof("connection recreated, retrying stream") + return fmt.Errorf("retrying after error: %w", err) +} + +func (c *GRPCClient) recreateConnection() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return backoff.Permanent(ErrClientClosed) + } + + conn, err := grpc.NewClient(c.target, c.opts...) if err != nil { - log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err) + c.mu.Unlock() + return fmt.Errorf("create new connection: %w", err) + } + + old := c.clientConn + c.clientConn = conn + c.realClient = proto.NewFlowServiceClient(conn) + c.stream = nil + c.mu.Unlock() + + _ = old.Close() + + return nil +} + +func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, backoff.Permanent(ErrClientClosed) + } + cl := c.realClient + c.mu.Unlock() + + // open stream outside the lock — blocking operation + stream, err := cl.Events(ctx) + if err != nil { + return nil, fmt.Errorf("create event stream: %w", err) + } + streamReady := false + defer func() { + if !streamReady { + _ = stream.CloseSend() + } + }() + + if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil { + return nil, fmt.Errorf("send initiator: %w", err) } if err = checkHeader(stream); err != nil { - return fmt.Errorf("check header: %w", err) + return nil, fmt.Errorf("check header: %w", err) } - c.streamMu.Lock() + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, backoff.Permanent(ErrClientClosed) + } c.stream = stream - c.streamMu.Unlock() + c.mu.Unlock() + streamReady = true - return c.receive(stream, msgHandler) + return stream, nil } func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error { for { msg, err := stream.Recv() if err != nil { - return fmt.Errorf("receive from stream: %w", err) + return err } if msg.IsInitiator { @@ -169,7 +292,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error { func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff { return backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 1, + RandomizationFactor: 0.5, Multiplier: 1.7, MaxInterval: interval / 2, MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months @@ -178,18 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff }, ctx) } -func (c *GRPCClient) Send(event *proto.FlowEvent) error { - c.streamMu.Lock() - stream := c.stream - c.streamMu.Unlock() - - if stream == nil { - return errors.New("stream not initialized") +func isContextDone(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true } - - if err := stream.Send(event); err != nil { - return fmt.Errorf("send flow event: %w", err) + if s, ok := status.FromError(err); ok { + return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded } - - return nil + return false } diff --git a/flow/client/client_test.go b/flow/client/client_test.go index efe01c003..55157acbc 100644 --- a/flow/client/client_test.go +++ b/flow/client/client_test.go @@ -2,8 +2,11 @@ package client_test import ( "context" + "encoding/binary" "errors" "net" + "sync" + "sync/atomic" "testing" "time" @@ -11,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" flow "github.com/netbirdio/netbird/flow/client" "github.com/netbirdio/netbird/flow/proto" @@ -18,21 +23,89 @@ import ( type testServer struct { proto.UnimplementedFlowServiceServer - events chan *proto.FlowEvent - acks chan *proto.FlowEventAck - grpcSrv *grpc.Server - addr string + events chan *proto.FlowEvent + acks chan *proto.FlowEventAck + grpcSrv *grpc.Server + addr string + listener *connTrackListener + closeStream chan struct{} // signal server to close the stream + handlerDone chan struct{} // signaled each time Events() exits + handlerStarted chan struct{} // signaled each time Events() begins +} + +// connTrackListener wraps a net.Listener to track accepted connections +// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM. +type connTrackListener struct { + net.Listener + mu sync.Mutex + conns []net.Conn +} + +func (l *connTrackListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + l.mu.Lock() + l.conns = append(l.conns, c) + l.mu.Unlock() + return c, nil +} + +// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR +// (error code 0x1) on every tracked connection. This produces the exact error: +// +// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR +// +// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload): +// +// Length (3 bytes): 0x000004 +// Type (1 byte): 0x03 (RST_STREAM) +// Flags (1 byte): 0x00 +// Stream ID (4 bytes): target stream (must have bit 31 clear) +// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR) +func (l *connTrackListener) connCount() int { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.conns) +} + +func (l *connTrackListener) sendRSTStream(streamID uint32) { + l.mu.Lock() + defer l.mu.Unlock() + + frame := make([]byte, 13) // 9-byte header + 4-byte payload + // Length = 4 (3 bytes, big-endian) + frame[0], frame[1], frame[2] = 0, 0, 4 + // Type = RST_STREAM (0x03) + frame[3] = 0x03 + // Flags = 0 + frame[4] = 0x00 + // Stream ID (4 bytes, big-endian, bit 31 reserved = 0) + binary.BigEndian.PutUint32(frame[5:9], streamID) + // Error Code = PROTOCOL_ERROR (0x1) + binary.BigEndian.PutUint32(frame[9:13], 0x1) + + for _, c := range l.conns { + _, _ = c.Write(frame) + } } func newTestServer(t *testing.T) *testServer { - listener, err := net.Listen("tcp", "127.0.0.1:0") + rawListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + listener := &connTrackListener{Listener: rawListener} + s := &testServer{ - events: make(chan *proto.FlowEvent, 100), - acks: make(chan *proto.FlowEventAck, 100), - grpcSrv: grpc.NewServer(), - addr: listener.Addr().String(), + events: make(chan *proto.FlowEvent, 100), + acks: make(chan *proto.FlowEventAck, 100), + grpcSrv: grpc.NewServer(), + addr: rawListener.Addr().String(), + listener: listener, + closeStream: make(chan struct{}, 1), + handlerDone: make(chan struct{}, 10), + handlerStarted: make(chan struct{}, 10), } proto.RegisterFlowServiceServer(s.grpcSrv, s) @@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer { } func (s *testServer) Events(stream proto.FlowService_EventsServer) error { + defer func() { + select { + case s.handlerDone <- struct{}{}: + default: + } + }() + err := stream.Send(&proto.FlowEventAck{IsInitiator: true}) if err != nil { return err } + select { + case s.handlerStarted <- struct{}{}: + default: + } + ctx, cancel := context.WithCancel(stream.Context()) defer cancel() @@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error { if err := stream.Send(ack); err != nil { return err } + case <-s.closeStream: + return status.Errorf(codes.Internal, "server closing stream") case <-ctx.Done(): return ctx.Err() } @@ -110,16 +197,13 @@ func TestReceive(t *testing.T) { assert.NoError(t, err, "failed to close flow") }) - receivedAcks := make(map[string]bool) + var ackCount atomic.Int32 receiveDone := make(chan struct{}) go func() { err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { if !msg.IsInitiator && len(msg.EventId) > 0 { - id := string(msg.EventId) - receivedAcks[id] = true - - if len(receivedAcks) >= 3 { + if ackCount.Add(1) >= 3 { close(receiveDone) } } @@ -130,7 +214,11 @@ func TestReceive(t *testing.T) { } }() - time.Sleep(500 * time.Millisecond) + select { + case <-server.handlerStarted: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for stream to be established") + } for i := 0; i < 3; i++ { eventID := uuid.New().String() @@ -153,7 +241,7 @@ func TestReceive(t *testing.T) { t.Fatal("timeout waiting for acks to be processed") } - assert.Equal(t, 3, len(receivedAcks)) + assert.Equal(t, int32(3), ackCount.Load()) } func TestReceive_ContextCancellation(t *testing.T) { @@ -254,3 +342,195 @@ func TestSend(t *testing.T) { t.Fatal("timeout waiting for ack to be received by flow") } } + +func TestNewClient_PermanentClose(t *testing.T) { + server := newTestServer(t) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + err = client.Close() + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + done := make(chan error, 1) + go func() { + done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + return nil + }) + }() + + select { + case err := <-done: + require.ErrorIs(t, err, flow.ErrClientClosed) + case <-time.After(2 * time.Second): + t.Fatal("Receive did not return after Close — stuck in retry loop") + } +} + +func TestNewClient_CloseVerify(t *testing.T) { + server := newTestServer(t) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + done := make(chan error, 1) + go func() { + done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + return nil + }) + }() + + closeDone := make(chan struct{}, 1) + go func() { + _ = client.Close() + closeDone <- struct{}{} + }() + + select { + case err := <-done: + require.Error(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Receive did not return after Close — stuck in retry loop") + } + + select { + case <-closeDone: + return + case <-time.After(2 * time.Second): + t.Fatal("Close did not return — blocked in retry loop") + } + +} + +func TestClose_WhileReceiving(t *testing.T) { + server := newTestServer(t) + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + + ctx := context.Background() // no timeout — intentional + receiveDone := make(chan struct{}) + go func() { + _ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + return nil + }) + close(receiveDone) + }() + + // Wait for the server-side handler to confirm the stream is established. + select { + case <-server.handlerStarted: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for stream to be established") + } + + closeDone := make(chan struct{}) + go func() { + _ = client.Close() + close(closeDone) + }() + + select { + case <-closeDone: + // Close returned — good + case <-time.After(2 * time.Second): + t.Fatal("Close blocked forever — Receive stuck in retry loop") + } + + select { + case <-receiveDone: + case <-time.After(2 * time.Second): + t.Fatal("Receive did not exit after Close") + } +} + +func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { + server := newTestServer(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + t.Cleanup(func() { + err := client.Close() + assert.NoError(t, err, "failed to close flow") + }) + + // Track acks received before and after server-side stream close + var ackCount atomic.Int32 + receivedFirst := make(chan struct{}) + receivedAfterReconnect := make(chan struct{}) + + go func() { + err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + if msg.IsInitiator || len(msg.EventId) == 0 { + return nil + } + n := ackCount.Add(1) + if n == 1 { + close(receivedFirst) + } + if n == 2 { + close(receivedAfterReconnect) + } + return nil + }) + if err != nil && !errors.Is(err, context.Canceled) { + t.Logf("receive error: %v", err) + } + }() + + // Wait for stream to be established, then send first ack + select { + case <-server.handlerStarted: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for stream to be established") + } + server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")} + + select { + case <-receivedFirst: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for first ack") + } + + // Snapshot connection count before injecting the fault. + connsBefore := server.listener.connCount() + + // Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection. + // gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated). + // Stream ID 1 is the client's first stream (our Events bidi stream). + // This produces the exact error the client sees in production: + // "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR" + server.listener.sendRSTStream(1) + + // Wait for the old Events() handler to fully exit so it can no longer + // drain s.acks and drop our injected ack on a broken stream. + select { + case <-server.handlerDone: + case <-time.After(5 * time.Second): + t.Fatal("old Events() handler did not exit after RST_STREAM") + } + + require.Eventually(t, func() bool { + return server.listener.connCount() > connsBefore + }, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM") + + server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")} + + select { + case <-receivedAfterReconnect: + // Client successfully reconnected and received ack after server-side stream close + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect") + } + + assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close") + assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)") +} diff --git a/go.mod b/go.mod index 4bcdbdc78..5172b1a78 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/netbirdio/netbird -go 1.25 - -toolchain go1.25.5 +go 1.25.5 require ( cunicu.li/go-rosenpass v0.4.0 @@ -15,28 +13,28 @@ require ( github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 github.com/rs/cors v1.8.0 - github.com/sirupsen/logrus v1.9.3 + github.com/sirupsen/logrus v1.9.4 github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.46.0 - golang.org/x/sys v0.39.0 + golang.org/x/crypto v0.49.0 + golang.org/x/sys v0.42.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.77.0 - google.golang.org/protobuf v1.36.10 - gopkg.in/natefinch/lumberjack.v2 v2.0.0 + google.golang.org/grpc v1.80.0 + google.golang.org/protobuf v1.36.11 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) require ( fyne.io/fyne/v2 v2.7.0 fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 - github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/awnumar/memguard v0.23.0 - github.com/aws/aws-sdk-go-v2 v1.36.3 - github.com/aws/aws-sdk-go-v2/config v1.29.14 - github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 + github.com/aws/aws-sdk-go-v2 v1.38.3 + github.com/aws/aws-sdk-go-v2/config v1.31.6 + github.com/aws/aws-sdk-go-v2/credentials v1.18.10 + github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 @@ -44,6 +42,8 @@ require ( github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-oidc/v3 v3.14.1 github.com/creack/pty v1.1.24 + github.com/crowdsecurity/crowdsec v1.7.7 + github.com/crowdsecurity/go-cs-bouncer v0.0.21 github.com/dexidp/dex v0.0.0-00010101000000-000000000000 github.com/dexidp/dex/api/v2 v2.4.0 github.com/eko/gocache/lib/v4 v4.2.0 @@ -51,6 +51,7 @@ require ( github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/fsnotify/fsnotify v1.9.0 github.com/gliderlabs/ssh v0.3.8 + github.com/go-jose/go-jose/v4 v4.1.3 github.com/godbus/dbus/v5 v5.1.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang/mock v1.6.0 @@ -61,15 +62,16 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 - github.com/hashicorp/go-version v1.6.0 + github.com/hashicorp/go-version v1.7.0 github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 + github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.1 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 @@ -103,37 +105,37 @@ require ( github.com/vmihailenco/msgpack/v5 v5.4.1 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 - go.opentelemetry.io/otel v1.38.0 - go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.38.0 - go.opentelemetry.io/otel/sdk/metric v1.38.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 + go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/exporters/prometheus v0.64.0 + go.opentelemetry.io/otel/metric v1.43.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 go.uber.org/mock v0.5.2 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 - golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/mobile v0.0.0-20251113184115-a159579294ab - golang.org/x/mod v0.30.0 - golang.org/x/net v0.47.0 - golang.org/x/oauth2 v0.34.0 - golang.org/x/sync v0.19.0 - golang.org/x/term v0.38.0 - golang.org/x/time v0.14.0 - google.golang.org/api v0.257.0 + golang.org/x/mod v0.33.0 + golang.org/x/net v0.52.0 + golang.org/x/oauth2 v0.36.0 + golang.org/x/sync v0.20.0 + golang.org/x/term v0.41.0 + golang.org/x/time v0.15.0 + google.golang.org/api v0.276.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 - gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c + gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 ) require ( - cloud.google.com/go/auth v0.17.0 // indirect + cloud.google.com/go/auth v0.20.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect dario.cat/mergo v1.0.1 // indirect - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.1.1 // indirect github.com/AppsFlyer/go-sundheit v0.6.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect @@ -144,37 +146,39 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/awnumar/memcall v0.4.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 // indirect github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect - github.com/aws/smithy-go v1.22.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.2 // indirect + github.com/aws/smithy-go v1.23.0 // indirect github.com/beevik/etree v1.6.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/crowdsecurity/go-cs-lib v0.0.25 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.0.1+incompatible // indirect - github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/ebitengine/purego v0.8.2 // indirect + github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect @@ -184,43 +188,59 @@ require ( github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect - github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-ldap/ldap/v3 v3.4.12 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect + github.com/go-openapi/analysis v0.23.0 // indirect + github.com/go-openapi/errors v0.22.2 // indirect + github.com/go-openapi/jsonpointer v0.21.1 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/loads v0.22.0 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/go-openapi/swag v0.23.1 // indirect + github.com/go-openapi/validate v0.24.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect + github.com/goccy/go-yaml v1.18.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/btree v1.1.2 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect - github.com/googleapis/gax-go/v2 v2.15.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.21.0 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/safejs v0.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/huandu/xstrings v1.5.0 // indirect + github.com/huin/goupnp v1.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jonboulle/clockwork v0.5.0 // indirect + github.com/josharian/intern v1.0.0 // indirect github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/koron/go-ssdp v0.0.4 // indirect github.com/kr/fs v0.1.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.10 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect @@ -228,6 +248,7 @@ require ( github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -239,7 +260,8 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect - github.com/nxadm/tail v1.4.8 // indirect + github.com/nxadm/tail v1.4.11 // indirect + github.com/oklog/ulid v1.3.1 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect @@ -249,40 +271,43 @@ require ( github.com/pion/transport/v2 v2.2.4 // indirect github.com/pion/turn/v4 v4.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.66.1 // indirect - github.com/prometheus/procfs v0.16.1 // indirect - github.com/russellhaering/goxmldsig v1.5.0 // indirect + github.com/prometheus/common v0.67.5 // indirect + github.com/prometheus/otlptranslator v1.0.0 // indirect + github.com/prometheus/procfs v0.19.2 // indirect + github.com/russellhaering/goxmldsig v1.6.0 // indirect github.com/rymdport/portal v0.4.2 // indirect - github.com/shirou/gopsutil/v4 v4.25.1 // indirect - github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/shirou/gopsutil/v4 v4.25.8 // indirect + github.com/shoenig/go-m1cpu v0.2.1 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/tklauser/go-sysconf v0.3.14 // indirect - github.com/tklauser/numcpus v0.8.0 // indirect + github.com/tklauser/go-sysconf v0.3.15 // indirect + github.com/tklauser/numcpus v0.10.0 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect + go.mongodb.org/mongo-driver v1.17.9 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect + go.opentelemetry.io/otel/sdk v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect go.uber.org/multierr v1.11.0 // indirect - go.yaml.in/yaml/v2 v2.4.2 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.32.0 // indirect - golang.org/x/tools v0.39.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/tools v0.42.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 diff --git a/go.sum b/go.sum index 1bd9396bb..9293ce73b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= -cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ= +cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA= +cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= @@ -9,8 +9,8 @@ cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= +filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ= fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4= @@ -34,56 +34,56 @@ github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSC github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= -github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo= -github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= -github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= -github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10/go.mod h1:qqvMj6gHLR/EXWZw4ZbqlPbQUyenf4h82UQUlKc+l14= -github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= -github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= -github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2 v1.38.3 h1:B6cV4oxnMs45fql4yRH+/Po/YU+597zgWqvDpYMturk= +github.com/aws/aws-sdk-go-v2 v1.38.3/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1/go.mod h1:ddqbooRZYNoJ2dsTwOty16rM+/Aqmk/GOXrK8cg7V00= +github.com/aws/aws-sdk-go-v2/config v1.31.6 h1:a1t8fXY4GT4xjyJExz4knbuoxSCacB5hT/WgtfPyLjo= +github.com/aws/aws-sdk-go-v2/config v1.31.6/go.mod h1:5ByscNi7R+ztvOGzeUaIu49vkMk2soq5NaH5PYe33MQ= +github.com/aws/aws-sdk-go-v2/credentials v1.18.10 h1:xdJnXCouCx8Y0NncgoptztUocIYLKeQxrCgN6x9sdhg= +github.com/aws/aws-sdk-go-v2/credentials v1.18.10/go.mod h1:7tQk08ntj914F/5i9jC4+2HQTAuJirq7m1vZVIhEkWs= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 h1:wbjnrrMnKew78/juW7I2BtKQwa1qlf6EjQgS69uYY14= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6/go.mod h1:AtiqqNrDioJXuUgz3+3T0mBWN7Hro2n9wll2zRUc0ww= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 h1:uF68eJA6+S9iVr9WgX1NaRGyQ/6MdIyc4JNUo6TN1FA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6/go.mod h1:qlPeVZCGPiobx8wb1ft0GHT5l+dc6ldnwInDFaMvC7Y= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 h1:pa1DEC6JoI0zduhZePp3zmhWvk/xxm4NB8Hy/Tlsgos= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6/go.mod h1:gxEjPebnhWGJoaDdtDkA0JX46VRg1wcTHYe63OfX5pE= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 h1:ZNTqv4nIdE/DiBfUUfXcLZ/Spcuz+RjeziUtNJackkM= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34/go.mod h1:zf7Vcd1ViW7cPqYWEHLHJkS50X0JS2IKz9Cgaj6ugrs= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 h1:lguz0bmOoGzozP9XfRJR1QIayEYo+2vP/No3OfLF0pU= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0/go.mod h1:iu6FSzgt+M2/x3Dk8zhycdIcHjEFb36IS8HVUVFoMg0= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91LiqT1nbvzDukyqAlCv89ZmwaHw/ZFlFZg= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6 h1:R0tNFJqfjHL3900cqhXuwQ+1K4G0xc9Yf8EDbFXCKEw= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6/go.mod h1:y/7sDdu+aJvPtGXr4xYosdpq9a6T9Z0jkXfugmti0rI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6 h1:hncKj/4gR+TPauZgTAsxOxNcvBayhUlYZ6LO/BYiQ30= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6/go.mod h1:OiIh45tp6HdJDDJGnja0mw8ihQGz3VGrUflLqSL0SmM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 h1:LHS1YAIJXJ4K9zS+1d/xa9JAA9sL2QyXIQCQFQW/X08= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6/go.mod h1:c9PCiTEuh0wQID5/KqA32J+HAgZxN9tOGXKCiYJjTZI= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 h1:nEXUSAwyUfLTgnc9cxlDWy637qsq4UWwp3sNAfl0Z3Y= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6/go.mod h1:HGzIULx4Ge3Do2V0FaiYKcyKzOqwrhUZgCI77NisswQ= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA= -github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 h1:tWUG+4wZqdMl/znThEk9tcCy8tTMxq8dW0JTgamohrY= -github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2/go.mod h1:U5SNqwhXB3Xe6F47kXvWihPl/ilGaEDe8HD/50Z9wxc= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= -github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= -github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= -github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 h1:ETkfWcXP2KNPLecaDa++5bsQhCRa5M5sLUJa5DWYIIg= +github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3/go.mod h1:+/3ZTqoYb3Ur7DObD00tarKMLMuKg8iqz5CHEanqTnw= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.1 h1:8OLZnVJPvjnrxEwHFg9hVUof/P4sibH+Ea4KKuqAGSg= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.1/go.mod h1:27M3BpVi0C02UiQh1w9nsBEit6pLhlaH3NHna6WUbDE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2 h1:gKWSTnqudpo8dAxqBqZnDoDWCiEh/40FziUjr/mo6uA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2/go.mod h1:x7+rkNmRoEN1U13A6JE2fXne9EWyJy54o3n6d4mGaXQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.2 h1:YZPjhyaGzhDQEvsffDEcpycq49nl7fiGcfJTIo8BszI= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.2/go.mod h1:2dIN8qhQfv37BdUYGgEC8Q3tteM3zFxTI1MLO2O3J3c= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE= github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -101,6 +101,8 @@ github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+Y github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= @@ -120,11 +122,18 @@ github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHf github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= +github.com/crowdsecurity/crowdsec v1.7.7 h1:sduZN763iXsrZodocWDrsR//7nLeffGu+RVkkIsbQkE= +github.com/crowdsecurity/crowdsec v1.7.7/go.mod h1:L1HLGPDnBYCcY+yfSFnuBbQ1G9DHEJN9c+Kevv9F+4Q= +github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy18CzYvYab8UJDU= +github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA= +github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4= +github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A= github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= @@ -133,12 +142,12 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/docker/docker v28.0.1+incompatible h1:FCHjSRdXhNRFjlHMTv4jUNlIBbTeRjrWfeFuJp7jpo0= github.com/docker/docker v28.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= -github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I= -github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= +github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/eko/gocache/lib/v4 v4.2.0 h1:MNykyi5Xw+5Wu3+PUrvtOCaKSZM1nUSVftbzmeC7Yuw= github.com/eko/gocache/lib/v4 v4.2.0/go.mod h1:7ViVmbU+CzDHzRpmB4SXKyyzyuJ8A3UW3/cszpcqB4M= github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEMl+NivxtR5R3/hw= @@ -157,6 +166,7 @@ github.com/fredbi/uri v1.1.1 h1:xZHJC08GZNIUhbP5ImTHnt5Ya0T8FI2VAwI/37kh2Ko= github.com/fredbi/uri v1.1.1/go.mod h1:4+DZQ5zBjEwQCDmXW5JdIjz0PUA+yJbvtBv+u+adr5o= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs= @@ -189,6 +199,24 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= +github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= +github.com/go-openapi/errors v0.22.2 h1:rdxhzcBUazEcGccKqbY1Y7NS8FDcMyIRr0934jrYnZg= +github.com/go-openapi/errors v0.22.2/go.mod h1:+n/5UdIqdVnLIJ6Q9Se8HNGUXYaY6CN8ImWzfi/Gzp0= +github.com/go-openapi/jsonpointer v0.21.1 h1:whnzv/pNXtK2FbX/W9yJfRmE2gsmkfahjMKB0fZvcic= +github.com/go-openapi/jsonpointer v0.21.1/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= +github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= +github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= +github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= +github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58= +github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= @@ -205,10 +233,14 @@ github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= +github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -232,6 +264,7 @@ github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl76 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -239,6 +272,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -250,10 +285,10 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= -github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= -github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= +github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= +github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= +github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= @@ -278,11 +313,13 @@ github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PU github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= -github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= +github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY= +github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -293,6 +330,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= @@ -317,6 +356,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= @@ -328,8 +369,10 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0= +github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -348,6 +391,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= +github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= @@ -355,6 +400,8 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= @@ -378,6 +425,8 @@ github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa1 github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -404,8 +453,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -417,10 +466,13 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXqaDbq7ju94viiQ= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY= +github.com/nxadm/tail v1.4.11/go.mod h1:OTaG3NK980DZzxbRq6lEuzgU+mug70nY11sMd4JXXHc= github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI= github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc= github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -441,8 +493,8 @@ github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq5 github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= -github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= @@ -480,8 +532,9 @@ github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw= github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= @@ -489,10 +542,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= -github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= +github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos= +github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= @@ -503,23 +558,25 @@ github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw= -github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk= +github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks= +github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8= -github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs= -github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI= -github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= +github.com/shirou/gopsutil/v4 v4.25.8 h1:NnAsw9lN7587WHxjJA9ryDnqhJpFH6A+wagYWTOH970= +github.com/shirou/gopsutil/v4 v4.25.8/go.mod h1:q9QdMmfAOVIw7a+eF86P7ISEU6ka+NLgkUxlopV4RwI= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= -github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= +github.com/shoenig/go-m1cpu v0.2.1 h1:yqRB4fvOge2+FyRXFkXqsyMoqPazv14Yyy+iyccT2E4= +github.com/shoenig/go-m1cpu v0.2.1/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk= +github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= @@ -568,11 +625,11 @@ github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40= github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= -github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= -github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= +github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= +github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= -github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY= -github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= +github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= +github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= @@ -601,28 +658,30 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= +go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU= +go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= -go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= -go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= -go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= -go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= -go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= -go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= -go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= -go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= -go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= -go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs= +go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -633,8 +692,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= -go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -648,10 +707,10 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -666,8 +725,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -686,11 +745,11 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= -golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -702,8 +761,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -723,8 +782,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -738,8 +797,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -752,8 +811,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -765,10 +824,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= -golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -780,8 +839,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -792,19 +851,19 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= -google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY= +google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= -google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= -google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= -google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= -google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= +google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI= +google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -815,8 +874,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= -google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -826,8 +885,8 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8 gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= -gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= @@ -852,5 +911,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA= -gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ= +gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA= +gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= diff --git a/idp/dex/config.go b/idp/dex/config.go index 57f832406..7f5300f14 100644 --- a/idp/dex/config.go +++ b/idp/dex/config.go @@ -5,7 +5,10 @@ import ( "encoding/json" "fmt" "log/slog" + "net/url" "os" + "strconv" + "strings" "time" "golang.org/x/crypto/bcrypt" @@ -167,20 +170,66 @@ type Connector struct { } // ToStorageConnector converts a Connector to storage.Connector type. +// It maps custom connector types (e.g., "zitadel", "entra") to Dex-native types +// and augments the config with OIDC defaults when needed. func (c *Connector) ToStorageConnector() (storage.Connector, error) { - data, err := json.Marshal(c.Config) + dexType, augmentedConfig := mapConnectorToDex(c.Type, c.Config) + + data, err := json.Marshal(augmentedConfig) if err != nil { return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err) } return storage.Connector{ ID: c.ID, - Type: c.Type, + Type: dexType, Name: c.Name, Config: data, }, nil } +// mapConnectorToDex maps custom connector types to Dex-native types and applies +// OIDC defaults. This ensures static connectors from config files or env vars +// are stored with types that Dex can open. +func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) { + switch connType { + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + return "oidc", applyOIDCDefaults(connType, config) + default: + return connType, config + } +} + +// applyOIDCDefaults clones the config map, sets common OIDC defaults, +// and applies provider-specific overrides. +func applyOIDCDefaults(connType string, config map[string]interface{}) map[string]interface{} { + augmented := make(map[string]interface{}, len(config)+4) + for k, v := range config { + augmented[k] = v + } + setDefault(augmented, "scopes", []string{"openid", "profile", "email"}) + setDefault(augmented, "insecureEnableGroups", true) + setDefault(augmented, "insecureSkipEmailVerified", true) + + switch connType { + case "zitadel": + setDefault(augmented, "getUserInfo", true) + case "entra": + setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"}) + case "okta", "pocketid": + augmented["scopes"] = []string{"openid", "profile", "email", "groups"} + } + + return augmented +} + +// setDefault sets a key in the map only if it doesn't already exist. +func setDefault(m map[string]interface{}, key string, value interface{}) { + if _, ok := m[key]; !ok { + m[key] = value + } +} + // StorageConfig is a configuration that can create a storage. type StorageConfig interface { Open(logger *slog.Logger) (storage.Storage, error) @@ -195,11 +244,175 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) { return nil, fmt.Errorf("sqlite3 storage requires 'file' config") } return (&sql.SQLite3{File: file}).Open(logger) + case "postgres": + dsn, _ := s.Config["dsn"].(string) + if dsn == "" { + return nil, fmt.Errorf("postgres storage requires 'dsn' config") + } + pg, err := parsePostgresDSN(dsn) + if err != nil { + return nil, fmt.Errorf("invalid postgres DSN: %w", err) + } + return pg.Open(logger) default: return nil, fmt.Errorf("unsupported storage type: %s", s.Type) } } +// parsePostgresDSN parses a DSN into a sql.Postgres config. +// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable) +// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values. +func parsePostgresDSN(dsn string) (*sql.Postgres, error) { + var params map[string]string + var err error + + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + params, err = parsePostgresURI(dsn) + } else { + params, err = parsePostgresKeyValue(dsn) + } + if err != nil { + return nil, err + } + + host := params["host"] + if host == "" { + host = "localhost" + } + + var port uint16 = 5432 + if p, ok := params["port"]; ok && p != "" { + v, err := strconv.ParseUint(p, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid port %q: %w", p, err) + } + if v == 0 { + return nil, fmt.Errorf("invalid port %q: must be non-zero", p) + } + port = uint16(v) + } + + dbname := params["dbname"] + if dbname == "" { + return nil, fmt.Errorf("dbname is required in DSN") + } + + pg := &sql.Postgres{ + NetworkDB: sql.NetworkDB{ + Host: host, + Port: port, + Database: dbname, + User: params["user"], + Password: params["password"], + }, + } + + if sslMode := params["sslmode"]; sslMode != "" { + switch sslMode { + case "disable", "allow", "prefer", "require", "verify-ca", "verify-full": + pg.SSL.Mode = sslMode + default: + return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode) + } + } + + return pg, nil +} + +// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs. +func parsePostgresURI(dsn string) (map[string]string, error) { + u, err := url.Parse(dsn) + if err != nil { + return nil, fmt.Errorf("invalid postgres URI: %w", err) + } + + params := make(map[string]string) + + if u.User != nil { + params["user"] = u.User.Username() + if p, ok := u.User.Password(); ok { + params["password"] = p + } + } + if u.Hostname() != "" { + params["host"] = u.Hostname() + } + if u.Port() != "" { + params["port"] = u.Port() + } + + dbname := strings.TrimPrefix(u.Path, "/") + if dbname != "" { + params["dbname"] = dbname + } + + for k, v := range u.Query() { + if len(v) > 0 { + params[k] = v[0] + } + } + + return params, nil +} + +// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values +// (e.g., password='my pass' host=localhost). +func parsePostgresKeyValue(dsn string) (map[string]string, error) { + params := make(map[string]string) + s := strings.TrimSpace(dsn) + + for s != "" { + eqIdx := strings.IndexByte(s, '=') + if eqIdx < 0 { + break + } + key := strings.TrimSpace(s[:eqIdx]) + + value, rest, err := parseDSNValue(s[eqIdx+1:]) + if err != nil { + return nil, fmt.Errorf("%w for key %q", err, key) + } + + params[key] = value + s = strings.TrimSpace(rest) + } + + return params, nil +} + +// parseDSNValue parses the next value from a libpq key=value string positioned after the '='. +// It returns the parsed value and the remaining unparsed string. +func parseDSNValue(s string) (value, rest string, err error) { + if len(s) > 0 && s[0] == '\'' { + return parseQuotedDSNValue(s[1:]) + } + // Unquoted value: read until whitespace. + idx := strings.IndexAny(s, " \t\n") + if idx < 0 { + return s, "", nil + } + return s[:idx], s[idx:], nil +} + +// parseQuotedDSNValue parses a single-quoted value starting after the opening quote. +// Libpq uses ” to represent a literal single quote inside quoted values. +func parseQuotedDSNValue(s string) (value, rest string, err error) { + var buf strings.Builder + for len(s) > 0 { + if s[0] == '\'' { + if len(s) > 1 && s[1] == '\'' { + buf.WriteByte('\'') + s = s[2:] + continue + } + return buf.String(), s[1:], nil + } + buf.WriteByte(s[0]) + s = s[1:] + } + return "", "", fmt.Errorf("unterminated quoted value") +} + // Validate validates the configuration func (c *YAMLConfig) Validate() error { if c.Issuer == "" { diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 68fe48486..24aed1b99 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -4,6 +4,7 @@ package dex import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "log/slog" @@ -19,10 +20,13 @@ import ( "github.com/dexidp/dex/server" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/sql" + jose "github.com/go-jose/go-jose/v4" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/bcrypt" "google.golang.org/grpc" + + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) // Config matches what management/internals/server/server.go expects @@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string { } return issuer + "/auth" } + +// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks. +// This avoids HTTP round-trips when the embedded IDP is co-located with the management server. +// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic: +// SigningKeyPub first, then all VerificationKeys, serialized via go-jose. +func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) { + keys, err := p.storage.GetKeys(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get keys from storage: %w", err) + } + + if keys.SigningKeyPub == nil { + return nil, fmt.Errorf("no public keys found in storage") + } + + // Build the key set exactly as Dex's localSigner.ValidationKeys does: + // signing key first, then all verification (rotated) keys. + joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1) + joseKeys = append(joseKeys, *keys.SigningKeyPub) + for _, vk := range keys.VerificationKeys { + if vk.PublicKey != nil { + joseKeys = append(joseKeys, *vk.PublicKey) + } + } + + // Serialize through go-jose (same as Dex's handlePublicKeys handler) + // then deserialize into our Jwks type, so the JSON field mapping is identical + // to what the /keys HTTP endpoint would return. + joseSet := jose.JSONWebKeySet{Keys: joseKeys} + data, err := json.Marshal(joseSet) + if err != nil { + return nil, fmt.Errorf("failed to marshal JWKS: %w", err) + } + + jwks := &nbjwt.Jwks{} + if err := json.Unmarshal(data, jwks); err != nil { + return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err) + } + + jwks.ExpiresInTime = keys.NextRotation + + return jwks, nil +} diff --git a/idp/dex/provider_test.go b/idp/dex/provider_test.go index bd2f676fb..4ed89fd2e 100644 --- a/idp/dex/provider_test.go +++ b/idp/dex/provider_test.go @@ -2,11 +2,14 @@ package dex import ( "context" + "encoding/json" "log/slog" "os" "path/filepath" "testing" + "github.com/dexidp/dex/storage" + sqllib "github.com/dexidp/dex/storage/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -197,6 +200,295 @@ enablePasswordDB: true t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID) } +// openTestStorage creates a SQLite storage in the given directory for testing. +func openTestStorage(t *testing.T, tmpDir string) storage.Storage { + t.Helper() + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + stor, err := (&sqllib.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger) + require.NoError(t, err) + return stor +} + +func TestStaticConnectors_CreatedFromYAML(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: My OIDC Provider + config: + issuer: https://accounts.example.com + clientID: test-client-id + clientSecret: test-client-secret + redirectURI: http://localhost:5556/dex/callback +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + // Open storage and run initializeStorage directly (avoids Dex server + // trying to dial the OIDC issuer) + stor := openTestStorage(t, tmpDir) + defer stor.Close() + + err = initializeStorage(ctx, stor, yamlConfig) + require.NoError(t, err) + + // Verify connector was created in storage + conn, err := stor.GetConnector(ctx, "my-oidc") + require.NoError(t, err) + assert.Equal(t, "my-oidc", conn.ID) + assert.Equal(t, "My OIDC Provider", conn.Name) + assert.Equal(t, "oidc", conn.Type) + + // Verify config fields were serialized correctly + var configMap map[string]interface{} + err = json.Unmarshal(conn.Config, &configMap) + require.NoError(t, err) + assert.Equal(t, "https://accounts.example.com", configMap["issuer"]) + assert.Equal(t, "test-client-id", configMap["clientID"]) +} + +func TestStaticConnectors_UpdatedOnRestart(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-update-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First: load config with initial connector + yamlContent1 := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + dbFile + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: Original Name + config: + issuer: https://accounts.example.com + clientID: original-client-id + clientSecret: original-secret +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent1), 0644) + require.NoError(t, err) + + yamlConfig1, err := LoadConfig(configPath) + require.NoError(t, err) + + stor := openTestStorage(t, tmpDir) + err = initializeStorage(ctx, stor, yamlConfig1) + require.NoError(t, err) + + // Verify initial state + conn, err := stor.GetConnector(ctx, "my-oidc") + require.NoError(t, err) + assert.Equal(t, "Original Name", conn.Name) + + var configMap1 map[string]interface{} + err = json.Unmarshal(conn.Config, &configMap1) + require.NoError(t, err) + assert.Equal(t, "original-client-id", configMap1["clientID"]) + + // Close storage to simulate restart + stor.Close() + + // Second: load updated config against the same DB + yamlContent2 := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + dbFile + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: Updated Name + config: + issuer: https://accounts.example.com + clientID: updated-client-id + clientSecret: updated-secret +` + err = os.WriteFile(configPath, []byte(yamlContent2), 0644) + require.NoError(t, err) + + yamlConfig2, err := LoadConfig(configPath) + require.NoError(t, err) + + stor2 := openTestStorage(t, tmpDir) + defer stor2.Close() + + err = initializeStorage(ctx, stor2, yamlConfig2) + require.NoError(t, err) + + // Verify connector was updated, not duplicated + allConnectors, err := stor2.ListConnectors(ctx) + require.NoError(t, err) + + nonLocalCount := 0 + for _, c := range allConnectors { + if c.ID != "local" { + nonLocalCount++ + } + } + assert.Equal(t, 1, nonLocalCount, "connector should be updated, not duplicated") + + conn2, err := stor2.GetConnector(ctx, "my-oidc") + require.NoError(t, err) + assert.Equal(t, "Updated Name", conn2.Name) + + var configMap2 map[string]interface{} + err = json.Unmarshal(conn2.Config, &configMap2) + require.NoError(t, err) + assert.Equal(t, "updated-client-id", configMap2["clientID"]) +} + +func TestStaticConnectors_MultipleConnectors(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-multi-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +connectors: +- type: oidc + id: my-oidc + name: My OIDC Provider + config: + issuer: https://accounts.example.com + clientID: oidc-client-id + clientSecret: oidc-secret +- type: google + id: my-google + name: Google Login + config: + clientID: google-client-id + clientSecret: google-secret +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + stor := openTestStorage(t, tmpDir) + defer stor.Close() + + err = initializeStorage(ctx, stor, yamlConfig) + require.NoError(t, err) + + allConnectors, err := stor.ListConnectors(ctx) + require.NoError(t, err) + + // Build a map for easier assertion + connByID := make(map[string]storage.Connector) + for _, c := range allConnectors { + connByID[c.ID] = c + } + + // Verify both static connectors exist + oidcConn, ok := connByID["my-oidc"] + require.True(t, ok, "oidc connector should exist") + assert.Equal(t, "My OIDC Provider", oidcConn.Name) + assert.Equal(t, "oidc", oidcConn.Type) + + var oidcConfig map[string]interface{} + err = json.Unmarshal(oidcConn.Config, &oidcConfig) + require.NoError(t, err) + assert.Equal(t, "oidc-client-id", oidcConfig["clientID"]) + + googleConn, ok := connByID["my-google"] + require.True(t, ok, "google connector should exist") + assert.Equal(t, "Google Login", googleConn.Name) + assert.Equal(t, "google", googleConn.Type) + + var googleConfig map[string]interface{} + err = json.Unmarshal(googleConn.Config, &googleConfig) + require.NoError(t, err) + assert.Equal(t, "google-client-id", googleConfig["clientID"]) + + // Verify local connector still exists alongside them (enablePasswordDB: true) + localConn, ok := connByID["local"] + require.True(t, ok, "local connector should exist") + assert.Equal(t, "local", localConn.Type) +} + +func TestStaticConnectors_EmptyList(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-static-conn-empty-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + provider, err := NewProviderFromYAML(ctx, yamlConfig) + require.NoError(t, err) + defer func() { _ = provider.Stop(ctx) }() + + // No static connectors configured, so ListConnectors should return empty + connectors, err := provider.ListConnectors(ctx) + require.NoError(t, err) + assert.Empty(t, connectors) + + // But local connector should still exist + localConn, err := provider.Storage().GetConnector(ctx, "local") + require.NoError(t, err) + assert.Equal(t, "local", localConn.ID) +} + func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) { ctx := context.Background() diff --git a/infrastructure_files/getting-started-with-dex.sh b/infrastructure_files/getting-started-with-dex.sh index a14c6134e..5e605f19c 100755 --- a/infrastructure_files/getting-started-with-dex.sh +++ b/infrastructure_files/getting-started-with-dex.sh @@ -172,8 +172,11 @@ init_environment() { echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" echo "" echo "Login with the following credentials:" - echo "Email: admin@$NETBIRD_DOMAIN" | tee .env - echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env + install -m 600 /dev/null .env + printf 'Email: admin@%s\nPassword: %s\n' \ + "$NETBIRD_DOMAIN" "$NETBIRD_ADMIN_PASSWORD" >> .env + echo "Email: admin@$NETBIRD_DOMAIN" + echo "Password: $NETBIRD_ADMIN_PASSWORD" echo "" echo "Dex admin UI is not available (Dex has no built-in UI)." echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex" diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 09c5225ad..f503cbeac 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -563,8 +563,11 @@ initEnvironment() { echo -e "\nDone!\n" echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" echo "Login with the following credentials:" - echo "Username: $ZITADEL_ADMIN_USERNAME" | tee .env - echo "Password: $ZITADEL_ADMIN_PASSWORD" | tee -a .env + install -m 600 /dev/null .env + printf 'Username: %s\nPassword: %s\n' \ + "$ZITADEL_ADMIN_USERNAME" "$ZITADEL_ADMIN_PASSWORD" >> .env + echo "Username: $ZITADEL_ADMIN_USERNAME" + echo "Password: $ZITADEL_ADMIN_PASSWORD" } renderCaddyfile() { diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 7fd87ee8e..08da48264 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,41 +182,20 @@ read_enable_proxy() { return 0 } -read_proxy_domain() { - local suggested_proxy="proxy.${BASE_DOMAIN}" - +read_enable_crowdsec() { echo "" > /dev/stderr - echo "NOTE: The proxy domain must be different from the management domain ($NETBIRD_DOMAIN)" > /dev/stderr - echo "to avoid TLS certificate conflicts." > /dev/stderr - echo "" > /dev/stderr - echo "You also need to add a wildcard DNS record for the proxy domain," > /dev/stderr - echo "e.g. *.${suggested_proxy} pointing to the same server domain as $NETBIRD_DOMAIN with a CNAME record." > /dev/stderr - echo "" > /dev/stderr - echo -n "Enter the domain for the NetBird Proxy (e.g. ${suggested_proxy}): " > /dev/stderr - read -r READ_PROXY_DOMAIN < /dev/tty + echo "Do you want to enable CrowdSec IP reputation blocking?" > /dev/stderr + echo "CrowdSec checks client IPs against a community threat intelligence database" > /dev/stderr + echo "and blocks known malicious sources before they reach your services." > /dev/stderr + echo "A local CrowdSec LAPI container will be added to your deployment." > /dev/stderr + echo -n "Enable CrowdSec? [y/N]: " > /dev/stderr + read -r CHOICE < /dev/tty - if [[ -z "$READ_PROXY_DOMAIN" ]]; then - echo "The proxy domain cannot be empty." > /dev/stderr - read_proxy_domain - return + if [[ "$CHOICE" =~ ^[Yy]$ ]]; then + echo "true" + else + echo "false" fi - - if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo ${READ_PROXY_DOMAIN} | grep ${NETBIRD_DOMAIN} > /dev/null - if [[ $? -eq 0 ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo "$READ_PROXY_DOMAIN" return 0 } @@ -334,8 +313,11 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" - PROXY_DOMAIN="" PROXY_TOKEN="" + + # CrowdSec configuration + ENABLE_CROWDSEC="false" + CROWDSEC_BOUNCER_KEY="" return 0 } @@ -365,7 +347,7 @@ configure_reverse_proxy() { TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) if [[ "$ENABLE_PROXY" == "true" ]]; then - PROXY_DOMAIN=$(read_proxy_domain) + ENABLE_CROWDSEC=$(read_enable_crowdsec) fi fi @@ -396,7 +378,7 @@ check_existing_installation() { echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "You can use the following commands:" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" - echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" + echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt && rm -rf crowdsec/" echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." exit 1 fi @@ -417,6 +399,9 @@ generate_configuration_files() { echo "NB_PROXY_TOKEN=placeholder" >> proxy.env # TCP ServersTransport for PROXY protocol v2 to the proxy backend render_traefik_dynamic > traefik-dynamic.yaml + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + mkdir -p crowdsec + fi fi ;; 1) @@ -459,8 +444,12 @@ start_services_and_show_instructions() { if [[ "$ENABLE_PROXY" == "true" ]]; then # Phase 1: Start core services (without proxy) + local core_services="traefik dashboard netbird-server" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + core_services="$core_services crowdsec" + fi echo "Starting core services..." - $DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server + $DOCKER_COMPOSE_COMMAND up -d $core_services sleep 3 wait_management_proxy traefik @@ -480,7 +469,33 @@ start_services_and_show_instructions() { echo "Proxy token created successfully." - # Generate proxy.env with the token + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + echo "Registering CrowdSec bouncer..." + local cs_retries=0 + while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do + cs_retries=$((cs_retries + 1)) + if [[ $cs_retries -ge 30 ]]; then + echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr + echo "You can register a bouncer manually later with:" > /dev/stderr + echo " docker exec netbird-crowdsec cscli bouncers add netbird-proxy -o raw" > /dev/stderr + ENABLE_CROWDSEC="false" + break + fi + sleep 2 + done + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + CROWDSEC_BOUNCER_KEY=$($DOCKER_COMPOSE_COMMAND exec -T crowdsec \ + cscli bouncers add netbird-proxy -o raw 2>/dev/null) + if [[ -z "$CROWDSEC_BOUNCER_KEY" ]]; then + echo "WARNING: Failed to create CrowdSec bouncer key. Skipping CrowdSec setup." > /dev/stderr + ENABLE_CROWDSEC="false" + else + echo "CrowdSec bouncer registered." + fi + fi + fi + render_proxy_env > proxy.env # Start proxy service @@ -567,11 +582,25 @@ render_docker_compose_traefik_builtin() { # Generate proxy service section and Traefik dynamic config if enabled local proxy_service="" local proxy_volumes="" + local crowdsec_service="" + local crowdsec_volumes="" local traefik_file_provider="" local traefik_dynamic_volume="" if [[ "$ENABLE_PROXY" == "true" ]]; then traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"' traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro" + + local proxy_depends=" + netbird-server: + condition: service_started" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + proxy_depends=" + netbird-server: + condition: service_started + crowdsec: + condition: service_healthy" + fi + proxy_service=" # NetBird Proxy - exposes internal resources to the internet proxy: @@ -581,8 +610,7 @@ render_docker_compose_traefik_builtin() { - 51820:51820/udp restart: unless-stopped networks: [netbird] - depends_on: - - netbird-server + depends_on:${proxy_depends} env_file: - ./proxy.env volumes: @@ -605,6 +633,35 @@ render_docker_compose_traefik_builtin() { " proxy_volumes=" netbird_proxy_certs:" + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + crowdsec_service=" + crowdsec: + image: crowdsecurity/crowdsec:v1.7.7 + container_name: netbird-crowdsec + restart: unless-stopped + networks: [netbird] + environment: + COLLECTIONS: crowdsecurity/linux + volumes: + - ./crowdsec:/etc/crowdsec + - crowdsec_db:/var/lib/crowdsec/data + healthcheck: + test: ["CMD", "cscli", "lapi", "status"] + interval: 10s + timeout: 5s + retries: 15 + labels: + - traefik.enable=false + logging: + driver: \"json-file\" + options: + max-size: \"500m\" + max-file: \"2\" +" + crowdsec_volumes=" + crowdsec_db:" + fi fi cat <" + echo " Get your enrollment key at: https://app.crowdsec.net" + echo "" + fi fi return 0 } diff --git a/infrastructure_files/observability/grafana/dashboards/management.json b/infrastructure_files/observability/grafana/dashboards/management.json index 95983603f..f116a8bde 100644 --- a/infrastructure_files/observability/grafana/dashboards/management.json +++ b/infrastructure_files/observability/grafana/dashboards/management.json @@ -302,7 +302,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_account_peer_meta_update_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_account_peer_meta_update_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -410,7 +410,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.5,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.5,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -426,7 +426,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.9,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.9,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -443,7 +443,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.99,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.99,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -545,7 +545,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.5,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.5,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -561,7 +561,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.9,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.9,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -578,7 +578,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.99,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.99,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -694,7 +694,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.5,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.5,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -710,7 +710,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.9,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.9,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -727,7 +727,7 @@ }, "disableTextWrap": false, "editorMode": "code", - "expr": "histogram_quantile(0.99,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", + "expr": "histogram_quantile(0.99,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))", "format": "heatmap", "fullMetaSearch": false, "hide": false, @@ -841,7 +841,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -853,7 +853,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -866,7 +866,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -963,7 +963,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -975,7 +975,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -988,7 +988,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1085,7 +1085,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -1097,7 +1097,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -1110,7 +1110,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1221,7 +1221,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_idp_authenticate_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_idp_authenticate_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -1317,7 +1317,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_idp_get_account_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_idp_get_account_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -1413,7 +1413,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "rate(management_idp_update_user_meta_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", + "expr": "rate(management_idp_update_user_meta_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])", "instant": false, "legendFormat": "{{cluster}}/{{environment}}/{{job}}", "range": true, @@ -1523,7 +1523,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"GET|OPTIONS\"}[$__rate_interval])) by (job,method)", + "expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"GET|OPTIONS\"}[$__rate_interval])) by (job,method)", "instant": false, "legendFormat": "{{method}}", "range": true, @@ -1619,7 +1619,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"POST|PUT|DELETE\"}[$__rate_interval])) by (job,method)", + "expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"POST|PUT|DELETE\"}[$__rate_interval])) by (job,method)", "instant": false, "legendFormat": "{{method}}", "range": true, @@ -1715,7 +1715,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -1727,7 +1727,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -1740,7 +1740,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1837,7 +1837,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", + "expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", "instant": false, "legendFormat": "p50", "range": true, @@ -1849,7 +1849,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", + "expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p90", @@ -1862,7 +1862,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", + "expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))", "hide": false, "instant": false, "legendFormat": "p99", @@ -1963,7 +1963,7 @@ "uid": "${datasource}" }, "editorMode": "code", - "expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (job,exported_endpoint,method)", + "expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (job,exported_endpoint,method)", "hide": false, "instant": false, "legendFormat": "{{method}}-{{exported_endpoint}}", @@ -3222,7 +3222,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "sum by(le) (increase(management_grpc_updatechannel_queue_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", + "expr": "sum by(le) (increase(management_grpc_updatechannel_queue_length_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, @@ -3323,7 +3323,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "sum by(le) (increase(management_account_update_account_peers_duration_ms_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", + "expr": "sum by(le) (increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))", "format": "heatmap", "fullMetaSearch": false, "includeNullMetadata": true, diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index b2b65f47a..4b414df6f 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -63,6 +63,8 @@ type Controller struct { expNewNetworkMap bool expNewNetworkMapAIDs map[string]struct{} + + compactedNetworkMap bool } type bufferUpdate struct { @@ -85,6 +87,17 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App newNetworkMapBuilder = false } + compactedNetworkMap := true + compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted) + parsedCompactedNmap, err := strconv.ParseBool(compactedEnv) + if err != nil && len(compactedEnv) > 0 { + log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err) + } + if err == nil && !parsedCompactedNmap { + log.WithContext(ctx).Info("disabling compacted mode") + compactedNetworkMap = false + } + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") expIDs := make(map[string]struct{}, len(ids)) for _, id := range ids { @@ -108,6 +121,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, expNewNetworkMapAIDs: expIDs, + + compactedNetworkMap: compactedNetworkMap, } } @@ -230,9 +245,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin var remotePeerNetworkMap *types.NetworkMap - if c.experimentalNetworkMap(accountID) { + switch { + case c.experimentalNetworkMap(accountID): remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { + case c.compactedNetworkMap: + remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + default: remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) } @@ -355,9 +373,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe var remotePeerNetworkMap *types.NetworkMap - if c.experimentalNetworkMap(accountId) { + switch { + case c.experimentalNetworkMap(accountId): remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { + case c.compactedNetworkMap: + remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + default: remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) } @@ -479,7 +500,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr } else { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers()) + groupIDToUserIDs := account.GetActiveGroupUsers() + if c.compactedNetworkMap { + networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -854,7 +880,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N account.InjectProxyPolicies(ctx) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + groupIDToUserIDs := account.GetActiveGroupUsers() + if c.compactedNetworkMap { + networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 2f796a5d1..d3f8f44ff 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -154,9 +154,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs return err } - eventsToStore = append(eventsToStore, func() { - m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) - }) + if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") { + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + } return nil }) @@ -210,7 +212,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee }, } - _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false) + _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true) if err != nil { return fmt.Errorf("failed to create proxy peer: %w", err) } diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index 019cb634a..f2ecfd5f9 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -1,6 +1,7 @@ package accesslogs import ( + "maps" "net" "net/netip" "time" @@ -10,20 +11,34 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// AccessLogProtocol identifies the transport protocol of an access log entry. +type AccessLogProtocol string + +const ( + AccessLogProtocolHTTP AccessLogProtocol = "http" + AccessLogProtocolTCP AccessLogProtocol = "tcp" + AccessLogProtocolUDP AccessLogProtocol = "udp" +) + type AccessLogEntry struct { - ID string `gorm:"primaryKey"` - AccountID string `gorm:"index"` - ServiceID string `gorm:"index"` - Timestamp time.Time `gorm:"index"` - GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"` - Method string `gorm:"index"` - Host string `gorm:"index"` - Path string `gorm:"index"` - Duration time.Duration `gorm:"index"` - StatusCode int `gorm:"index"` - Reason string - UserId string `gorm:"index"` - AuthMethodUsed string `gorm:"index"` + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + ServiceID string `gorm:"index"` + Timestamp time.Time `gorm:"index"` + GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"` + SubdivisionCode string + Method string `gorm:"index"` + Host string `gorm:"index"` + Path string `gorm:"index"` + Duration time.Duration `gorm:"index"` + StatusCode int `gorm:"index"` + Reason string + UserId string `gorm:"index"` + AuthMethodUsed string `gorm:"index"` + BytesUpload int64 `gorm:"index"` + BytesDownload int64 `gorm:"index"` + Protocol AccessLogProtocol `gorm:"index"` + Metadata map[string]string `gorm:"serializer:json"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -39,17 +54,25 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.UserId = serviceLog.GetUserId() a.AuthMethodUsed = serviceLog.GetAuthMechanism() a.AccountID = serviceLog.GetAccountId() + a.BytesUpload = serviceLog.GetBytesUpload() + a.BytesDownload = serviceLog.GetBytesDownload() + a.Protocol = AccessLogProtocol(serviceLog.GetProtocol()) + a.Metadata = maps.Clone(serviceLog.GetMetadata()) if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { - if ip, err := netip.ParseAddr(sourceIP); err == nil { - a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice()) + if addr, err := netip.ParseAddr(sourceIP); err == nil { + addr = addr.Unmap() + a.GeoLocation.ConnectionIP = net.IP(addr.AsSlice()) } } - if !serviceLog.GetAuthSuccess() { - a.Reason = "Authentication failed" - } else if serviceLog.GetResponseCode() >= 400 { - a.Reason = "Request failed" + // Only set reason for HTTP entries. L4 entries have no auth or status code. + if a.Protocol == "" || a.Protocol == AccessLogProtocolHTTP { + if !serviceLog.GetAuthSuccess() { + a.Reason = "Authentication failed" + } else if serviceLog.GetResponseCode() >= 400 { + a.Reason = "Request failed" + } } } @@ -86,20 +109,41 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { cityName = &a.GeoLocation.CityName } + var subdivisionCode *string + if a.SubdivisionCode != "" { + subdivisionCode = &a.SubdivisionCode + } + + var protocol *string + if a.Protocol != "" { + p := string(a.Protocol) + protocol = &p + } + + var metadata *map[string]string + if len(a.Metadata) > 0 { + metadata = &a.Metadata + } + return &api.ProxyAccessLog{ - Id: a.ID, - ServiceId: a.ServiceID, - Timestamp: a.Timestamp, - Method: a.Method, - Host: a.Host, - Path: a.Path, - DurationMs: int(a.Duration.Milliseconds()), - StatusCode: a.StatusCode, - SourceIp: sourceIP, - Reason: reason, - UserId: userID, - AuthMethodUsed: authMethod, - CountryCode: countryCode, - CityName: cityName, + Id: a.ID, + ServiceId: a.ServiceID, + Timestamp: a.Timestamp, + Method: a.Method, + Host: a.Host, + Path: a.Path, + DurationMs: int(a.Duration.Milliseconds()), + StatusCode: a.StatusCode, + SourceIp: sourceIP, + Reason: reason, + UserId: userID, + AuthMethodUsed: authMethod, + CountryCode: countryCode, + CityName: cityName, + SubdivisionCode: subdivisionCode, + BytesUpload: a.BytesUpload, + BytesDownload: a.BytesDownload, + Protocol: protocol, + Metadata: metadata, } } diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index e7fba7bed..59d7704eb 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac logEntry.GeoLocation.CountryCode = location.Country.ISOCode logEntry.GeoLocation.CityName = location.City.Names.En logEntry.GeoLocation.GeoNameID = location.City.GeonameID + if len(location.Subdivisions) > 0 { + logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode + } } } @@ -103,13 +106,23 @@ func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays in // StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) { - if retentionDays <= 0 { - log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative") + if retentionDays < 0 { + log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is negative") return } + if retentionDays == 0 { + retentionDays = 7 + log.WithContext(ctx).Debugf("no retention days specified for access log cleanup, defaulting to %d days", retentionDays) + } else { + log.WithContext(ctx).Debugf("access log retention period set to %d days", retentionDays) + } + if cleanupIntervalHours <= 0 { cleanupIntervalHours = 24 + log.WithContext(ctx).Debugf("no cleanup interval specified for access log cleanup, defaulting to %d hours", cleanupIntervalHours) + } else { + log.WithContext(ctx).Debugf("access log cleanup interval set to %d hours", cleanupIntervalHours) } cleanupCtx, cancel := context.WithCancel(ctx) diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go index 8fadef85f..11bf60829 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go @@ -121,7 +121,7 @@ func TestCleanupWithExactBoundary(t *testing.T) { } func TestStartPeriodicCleanup(t *testing.T) { - t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) { + t.Run("periodic cleanup disabled with negative retention", func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -135,7 +135,7 @@ func TestStartPeriodicCleanup(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - manager.StartPeriodicCleanup(ctx, 0, 1) + manager.StartPeriodicCleanup(ctx, -1, 1) time.Sleep(100 * time.Millisecond) diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index da3432626..f65e31a07 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -14,4 +14,27 @@ type Domain struct { TargetCluster string // The proxy cluster this domain should be validated against Type Type `gorm:"-"` Validated bool + // SupportsCustomPorts is populated at query time for free domains from the + // proxy cluster capabilities. Not persisted. + SupportsCustomPorts *bool `gorm:"-"` + // RequireSubdomain is populated at query time. When true, the domain + // cannot be used bare and a subdomain label must be prepended. Not persisted. + RequireSubdomain *bool `gorm:"-"` + // SupportsCrowdSec is populated at query time from proxy cluster capabilities. + // Not persisted. + SupportsCrowdSec *bool `gorm:"-"` +} + +// EventMeta returns activity event metadata for a domain +func (d *Domain) EventMeta() map[string]any { + return map[string]any{ + "domain": d.Domain, + "target_cluster": d.TargetCluster, + "validated": d.Validated, + } +} + +func (d *Domain) Copy() *Domain { + dCopy := *d + return &dCopy } diff --git a/management/internals/modules/reverseproxy/domain/interface.go b/management/internals/modules/reverseproxy/domain/interface.go index d40e9b637..a4bba5841 100644 --- a/management/internals/modules/reverseproxy/domain/interface.go +++ b/management/internals/modules/reverseproxy/domain/interface.go @@ -9,4 +9,5 @@ type Manager interface { CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error ValidateDomain(ctx context.Context, accountID, userID, domainID string) + GetClusterDomains() []string } diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 2fbcdd5b8..4493ef0ad 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -42,10 +42,13 @@ func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { func domainToApi(d *domain.Domain) api.ReverseProxyDomain { resp := api.ReverseProxyDomain{ - Domain: d.Domain, - Id: d.ID, - Type: domainTypeToApi(d.Type), - Validated: d.Validated, + Domain: d.Domain, + Id: d.ID, + Type: domainTypeToApi(d.Type), + Validated: d.Validated, + SupportsCustomPorts: d.SupportsCustomPorts, + RequireSubdomain: d.RequireSubdomain, + SupportsCrowdsec: d.SupportsCrowdSec, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/domain_test.go b/management/internals/modules/reverseproxy/domain/manager/domain_test.go new file mode 100644 index 000000000..523920a99 --- /dev/null +++ b/management/internals/modules/reverseproxy/domain/manager/domain_test.go @@ -0,0 +1,172 @@ +package manager + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" +) + +func TestExtractClusterFromFreeDomain(t *testing.T) { + clusters := []string{"eu1.proxy.netbird.io", "us1.proxy.netbird.io"} + + tests := []struct { + name string + domain string + wantOK bool + wantVal string + }{ + { + name: "subdomain of cluster matches", + domain: "myapp.eu1.proxy.netbird.io", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "deep subdomain of cluster matches", + domain: "foo.bar.eu1.proxy.netbird.io", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "bare cluster domain matches", + domain: "eu1.proxy.netbird.io", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "unrelated domain does not match", + domain: "example.com", + wantOK: false, + }, + { + name: "partial suffix does not match", + domain: "fakeu1.proxy.netbird.io", + wantOK: false, + }, + { + name: "second cluster matches", + domain: "app.us1.proxy.netbird.io", + wantOK: true, + wantVal: "us1.proxy.netbird.io", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cluster, ok := ExtractClusterFromFreeDomain(tc.domain, clusters) + assert.Equal(t, tc.wantOK, ok) + if ok { + assert.Equal(t, tc.wantVal, cluster) + } + }) + } +} + +func TestExtractClusterFromCustomDomains(t *testing.T) { + customDomains := []*domain.Domain{ + {Domain: "example.com", TargetCluster: "eu1.proxy.netbird.io"}, + {Domain: "proxy.corp.io", TargetCluster: "us1.proxy.netbird.io"}, + } + + tests := []struct { + name string + domain string + wantOK bool + wantVal string + }{ + { + name: "subdomain of custom domain matches", + domain: "app.example.com", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "bare custom domain matches", + domain: "example.com", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "deep subdomain of custom domain matches", + domain: "a.b.example.com", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "subdomain of multi-level custom domain matches", + domain: "app.proxy.corp.io", + wantOK: true, + wantVal: "us1.proxy.netbird.io", + }, + { + name: "bare multi-level custom domain matches", + domain: "proxy.corp.io", + wantOK: true, + wantVal: "us1.proxy.netbird.io", + }, + { + name: "unrelated domain does not match", + domain: "other.com", + wantOK: false, + }, + { + name: "partial suffix does not match custom domain", + domain: "fakeexample.com", + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains) + assert.Equal(t, tc.wantOK, ok) + if ok { + assert.Equal(t, tc.wantVal, cluster) + } + }) + } +} + +func TestExtractClusterFromCustomDomains_OverlappingDomains(t *testing.T) { + customDomains := []*domain.Domain{ + {Domain: "example.com", TargetCluster: "cluster-generic"}, + {Domain: "app.example.com", TargetCluster: "cluster-app"}, + } + + tests := []struct { + name string + domain string + wantVal string + }{ + { + name: "exact match on more specific domain", + domain: "app.example.com", + wantVal: "cluster-app", + }, + { + name: "subdomain of more specific domain", + domain: "api.app.example.com", + wantVal: "cluster-app", + }, + { + name: "subdomain of generic domain", + domain: "other.example.com", + wantVal: "cluster-generic", + }, + { + name: "bare generic domain", + domain: "example.com", + wantVal: "cluster-generic", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains) + assert.True(t, ok) + assert.Equal(t, tc.wantVal, cluster) + }) + } +} diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 1125f428f..2c4c1372e 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -9,6 +9,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -27,25 +29,28 @@ type store interface { DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error } -type proxyURLProvider interface { - GetConnectedProxyURLs() []string +type proxyManager interface { + GetActiveClusterAddresses(ctx context.Context) ([]string, error) + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool } type Manager struct { store store validator domain.Validator - proxyURLProvider proxyURLProvider + proxyManager proxyManager permissionsManager permissions.Manager + accountManager account.Manager } -func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager { +func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager { return Manager{ - store: store, - proxyURLProvider: proxyURLProvider, - validator: domain.Validator{ - Resolver: net.DefaultResolver, - }, + store: store, + proxyManager: proxyMgr, + validator: domain.Validator{Resolver: net.DefaultResolver}, permissionsManager: permissionsManager, + accountManager: accountManager, } } @@ -67,31 +72,46 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d // Add connected proxy clusters as free domains. // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). - allowList := m.proxyURLAllowList() - log.WithFields(log.Fields{ + allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) + return nil, err + } + log.WithContext(ctx).WithFields(log.Fields{ "accountID": accountID, "proxyAllowList": allowList, }).Debug("getting domains with proxy allow list") for _, cluster := range allowList { - ret = append(ret, &domain.Domain{ + d := &domain.Domain{ Domain: cluster, AccountID: accountID, Type: domain.TypeFree, Validated: true, - }) + } + d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster) + d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster) + d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster) + ret = append(ret, d) } // Add custom domains. for _, d := range domains { - ret = append(ret, &domain.Domain{ + cd := &domain.Domain{ ID: d.ID, Domain: d.Domain, AccountID: accountID, TargetCluster: d.TargetCluster, Type: domain.TypeCustom, Validated: d.Validated, - }) + } + if d.TargetCluster != "" { + cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster) + cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster) + } + // Custom domains never require a subdomain by default since + // the account owns them and should be able to use the bare domain. + ret = append(ret, cd) } return ret, nil @@ -107,7 +127,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName } // Verify the target cluster is in the available clusters - allowList := m.proxyURLAllowList() + allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) + } clusterValid := false for _, cluster := range allowList { if cluster == targetCluster { @@ -129,6 +152,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName if err != nil { return d, fmt.Errorf("create domain in store: %w", err) } + + m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta()) + return d, nil } @@ -141,10 +167,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s return status.NewPermissionDeniedError() } + d, err := m.store.GetCustomDomain(ctx, accountID, domainID) + if err != nil { + return fmt.Errorf("get domain from store: %w", err) + } + if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil { // TODO: check for "no records" type error. Because that is a success condition. return fmt.Errorf("delete domain from store: %w", err) } + + m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta()) + return nil } @@ -211,6 +245,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID }).WithError(err).Error("update custom domain in store") return } + + m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta()) } else { log.WithFields(log.Fields{ "accountID": accountID, @@ -221,21 +257,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID } } -// proxyURLAllowList retrieves a list of currently connected proxies and -// their URLs -func (m Manager) proxyURLAllowList() []string { - var reverseProxyAddresses []string - if m.proxyURLProvider != nil { - reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs() +// GetClusterDomains returns a list of proxy cluster domains. +func (m Manager) GetClusterDomains() []string { + if m.proxyManager == nil { + return nil } - return reverseProxyAddresses + addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background()) + if err != nil { + return nil + } + return addresses } // DeriveClusterFromDomain determines the proxy cluster for a given domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster. func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { - allowList := m.proxyURLAllowList() + allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) + } if len(allowList) == 0 { return "", fmt.Errorf("no proxy clusters available") } @@ -257,13 +298,19 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } -func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) { - for _, customDomain := range customDomains { - if strings.HasSuffix(domain, "."+customDomain.Domain) { - return customDomain.TargetCluster, true +func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { + bestCluster := "" + bestLen := -1 + for _, cd := range customDomains { + if serviceDomain != cd.Domain && !strings.HasSuffix(serviceDomain, "."+cd.Domain) { + continue + } + if l := len(cd.Domain); l > bestLen { + bestLen = l + bestCluster = cd.TargetCluster } } - return "", false + return bestCluster, bestLen >= 0 } // ExtractClusterFromFreeDomain extracts the cluster address from a free domain. @@ -271,7 +318,7 @@ func extractClusterFromCustomDomains(domain string, customDomains []*domain.Doma // It matches the domain suffix against available clusters and returns the matching cluster. func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) { for _, cluster := range availableClusters { - if strings.HasSuffix(domain, "."+cluster) { + if domain == cluster || strings.HasSuffix(domain, "."+cluster) { return cluster, true } } diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go deleted file mode 100644 index 535705a37..000000000 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ /dev/null @@ -1,539 +0,0 @@ -package manager - -import ( - "context" - "fmt" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" - nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" - "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" -) - -const unknownHostPlaceholder = "unknown" - -// ClusterDeriver derives the proxy cluster from a domain. -type ClusterDeriver interface { - DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) -} - -type managerImpl struct { - store store.Store - accountManager account.Manager - permissionsManager permissions.Manager - proxyGRPCServer *nbgrpc.ProxyServiceServer - clusterDeriver ClusterDeriver -} - -// NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { - return &managerImpl{ - store: store, - accountManager: accountManager, - permissionsManager: permissionsManager, - proxyGRPCServer: proxyGRPCServer, - clusterDeriver: clusterDeriver, - } -} - -func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - - services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, fmt.Errorf("failed to get services: %w", err) - } - - for _, service := range services { - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - } - - return services, nil -} - -func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error { - for _, target := range service.Targets { - switch target.TargetType { - case reverseproxy.TargetTypePeer: - peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) - if err != nil { - log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err) - target.Host = unknownHostPlaceholder - continue - } - target.Host = peer.IP.String() - case reverseproxy.TargetTypeHost: - resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) - if err != nil { - log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) - target.Host = unknownHostPlaceholder - continue - } - target.Host = resource.Prefix.Addr().String() - case reverseproxy.TargetTypeDomain: - resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) - if err != nil { - log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) - target.Host = unknownHostPlaceholder - continue - } - target.Host = resource.Domain - case reverseproxy.TargetTypeSubnet: - // For subnets we do not do any lookups on the resource - default: - return fmt.Errorf("unknown target type: %s", target.TargetType) - } - } - return nil -} - -func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - - service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) - if err != nil { - return nil, fmt.Errorf("failed to get service: %w", err) - } - - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - return service, nil -} - -func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - - if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { - return nil, err - } - - if err := m.persistNewService(ctx, accountID, service); err != nil { - return nil, err - } - - m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta()) - - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) - - m.accountManager.UpdateAccountPeers(ctx, accountID) - - return service, nil -} - -func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error { - if m.clusterDeriver != nil { - proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) - if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain) - return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err) - } - service.ProxyCluster = proxyCluster - } - - service.AccountID = accountID - service.InitNewRecord() - - if err := service.Auth.HashSecrets(); err != nil { - return fmt.Errorf("hash secrets: %w", err) - } - - keyPair, err := sessionkey.GenerateKeyPair() - if err != nil { - return fmt.Errorf("generate session keys: %w", err) - } - service.SessionPrivateKey = keyPair.PrivateKey - service.SessionPublicKey = keyPair.PublicKey - - return nil -} - -func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error { - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { - return err - } - - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } - - if err := transaction.CreateService(ctx, service); err != nil { - return fmt.Errorf("failed to create service: %w", err) - } - - return nil - }) -} - -func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { - existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) - if err != nil { - if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { - return fmt.Errorf("failed to check existing service: %w", err) - } - return nil - } - - if existingService != nil && existingService.ID != excludeServiceID { - return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain) - } - - return nil -} - -func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - - if err := service.Auth.HashSecrets(); err != nil { - return nil, fmt.Errorf("hash secrets: %w", err) - } - - updateInfo, err := m.persistServiceUpdate(ctx, accountID, service) - if err != nil { - return nil, err - } - - m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta()) - - if err := m.replaceHostByLookup(ctx, accountID, service); err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - - m.sendServiceUpdateNotifications(service, updateInfo) - m.accountManager.UpdateAccountPeers(ctx, accountID) - - return service, nil -} - -type serviceUpdateInfo struct { - oldCluster string - domainChanged bool - serviceEnabledChanged bool -} - -func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) { - var updateInfo serviceUpdateInfo - - err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) - if err != nil { - return err - } - - updateInfo.oldCluster = existingService.ProxyCluster - updateInfo.domainChanged = existingService.Domain != service.Domain - - if updateInfo.domainChanged { - if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { - return err - } - } else { - service.ProxyCluster = existingService.ProxyCluster - } - - m.preserveExistingAuthSecrets(service, existingService) - m.preserveServiceMetadata(service, existingService) - updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled - - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } - - if err := transaction.UpdateService(ctx, service); err != nil { - return fmt.Errorf("update service: %w", err) - } - - return nil - }) - - return &updateInfo, err -} - -func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error { - if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { - return err - } - - if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) - if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain) - } else { - service.ProxyCluster = newCluster - } - } - - return nil -} - -func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) { - if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && - existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && - service.Auth.PasswordAuth.Password == "" { - service.Auth.PasswordAuth = existingService.Auth.PasswordAuth - } - - if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled && - existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled && - service.Auth.PinAuth.Pin == "" { - service.Auth.PinAuth = existingService.Auth.PinAuth - } -} - -func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) { - service.Meta = existingService.Meta - service.SessionPrivateKey = existingService.SessionPrivateKey - service.SessionPublicKey = existingService.SessionPublicKey -} - -func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { - oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() - - switch { - case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster) - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) - case !service.Enabled && updateInfo.serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster) - case service.Enabled && updateInfo.serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) - default: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster) - } -} - -// validateTargetReferences checks that all target IDs reference existing peers or resources in the account. -func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error { - for _, target := range targets { - switch target.TargetType { - case reverseproxy.TargetTypePeer: - if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) - } - case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain: - if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) - } - } - } - return nil -} - -func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - - var service *reverseproxy.Service - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - var err error - service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) - if err != nil { - return err - } - - if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil { - return fmt.Errorf("failed to delete service: %w", err) - } - - return nil - }) - if err != nil { - return err - } - - m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) - - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) - - m.accountManager.UpdateAccountPeers(ctx, accountID) - - return nil -} - -// SetCertificateIssuedAt sets the certificate issued timestamp to the current time. -// Call this when receiving a gRPC notification that the certificate was issued. -func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) - if err != nil { - return fmt.Errorf("failed to get service: %w", err) - } - - service.Meta.CertificateIssuedAt = time.Now() - - if err = transaction.UpdateService(ctx, service); err != nil { - return fmt.Errorf("failed to update service certificate timestamp: %w", err) - } - - return nil - }) -} - -// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) -func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) - if err != nil { - return fmt.Errorf("failed to get service: %w", err) - } - - service.Meta.Status = string(status) - - if err = transaction.UpdateService(ctx, service); err != nil { - return fmt.Errorf("failed to update service status: %w", err) - } - - return nil - }) -} - -func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error { - service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) - if err != nil { - return fmt.Errorf("failed to get service: %w", err) - } - - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) - - m.accountManager.UpdateAccountPeers(ctx, accountID) - - return nil -} - -func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { - services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return fmt.Errorf("failed to get services: %w", err) - } - - for _, service := range services { - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) - } - - return nil -} - -func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { - services, err := m.store.GetServices(ctx, store.LockingStrengthNone) - if err != nil { - return nil, fmt.Errorf("failed to get services: %w", err) - } - - for _, service := range services { - err = m.replaceHostByLookup(ctx, service.AccountID, service) - if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - } - - return services, nil -} - -func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { - service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) - if err != nil { - return nil, fmt.Errorf("failed to get service: %w", err) - } - - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - - return service, nil -} - -func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { - services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, fmt.Errorf("failed to get services: %w", err) - } - - for _, service := range services { - err = m.replaceHostByLookup(ctx, accountID, service) - if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) - } - } - - return services, nil -} - -func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { - target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) - if err != nil { - if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - return "", nil - } - return "", fmt.Errorf("failed to get service target by resource ID: %w", err) - } - - if target == nil { - return "", nil - } - - return target.ServiceID, nil -} diff --git a/management/internals/modules/reverseproxy/manager/manager_test.go b/management/internals/modules/reverseproxy/manager/manager_test.go deleted file mode 100644 index 266b0066f..000000000 --- a/management/internals/modules/reverseproxy/manager/manager_test.go +++ /dev/null @@ -1,375 +0,0 @@ -package manager - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" -) - -func TestInitializeServiceForCreate(t *testing.T) { - ctx := context.Background() - accountID := "test-account" - - t.Run("successful initialization without cluster deriver", func(t *testing.T) { - mgr := &managerImpl{ - clusterDeriver: nil, - } - - service := &reverseproxy.Service{ - Domain: "example.com", - Auth: reverseproxy.AuthConfig{}, - } - - err := mgr.initializeServiceForCreate(ctx, accountID, service) - - assert.NoError(t, err) - assert.Equal(t, accountID, service.AccountID) - assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver") - assert.NotEmpty(t, service.ID, "service ID should be initialized") - assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated") - assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated") - }) - - t.Run("verifies session keys are different", func(t *testing.T) { - mgr := &managerImpl{ - clusterDeriver: nil, - } - - service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}} - service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}} - - err1 := mgr.initializeServiceForCreate(ctx, accountID, service1) - err2 := mgr.initializeServiceForCreate(ctx, accountID, service2) - - assert.NoError(t, err1) - assert.NoError(t, err2) - assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique") - assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique") - }) -} - -func TestCheckDomainAvailable(t *testing.T) { - ctx := context.Background() - accountID := "test-account" - - tests := []struct { - name string - domain string - excludeServiceID string - setupMock func(*store.MockStore) - expectedError bool - errorType status.Type - }{ - { - name: "domain available - not found", - domain: "available.com", - excludeServiceID: "", - setupMock: func(ms *store.MockStore) { - ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "available.com"). - Return(nil, status.Errorf(status.NotFound, "not found")) - }, - expectedError: false, - }, - { - name: "domain already exists", - domain: "exists.com", - excludeServiceID: "", - setupMock: func(ms *store.MockStore) { - ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). - Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil) - }, - expectedError: true, - errorType: status.AlreadyExists, - }, - { - name: "domain exists but excluded (same ID)", - domain: "exists.com", - excludeServiceID: "service-123", - setupMock: func(ms *store.MockStore) { - ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). - Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) - }, - expectedError: false, - }, - { - name: "domain exists with different ID", - domain: "exists.com", - excludeServiceID: "service-456", - setupMock: func(ms *store.MockStore) { - ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). - Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) - }, - expectedError: true, - errorType: status.AlreadyExists, - }, - { - name: "store error (non-NotFound)", - domain: "error.com", - excludeServiceID: "", - setupMock: func(ms *store.MockStore) { - ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "error.com"). - Return(nil, errors.New("database error")) - }, - expectedError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStore := store.NewMockStore(ctrl) - tt.setupMock(mockStore) - - mgr := &managerImpl{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) - - if tt.expectedError { - require.Error(t, err) - if tt.errorType != 0 { - sErr, ok := status.FromError(err) - require.True(t, ok, "error should be a status error") - assert.Equal(t, tt.errorType, sErr.Type()) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestCheckDomainAvailable_EdgeCases(t *testing.T) { - ctx := context.Background() - accountID := "test-account" - - t.Run("empty domain", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStore := store.NewMockStore(ctrl) - mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, ""). - Return(nil, status.Errorf(status.NotFound, "not found")) - - mgr := &managerImpl{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") - - assert.NoError(t, err) - }) - - t.Run("empty exclude ID with existing service", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStore := store.NewMockStore(ctrl) - mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, "test.com"). - Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil) - - mgr := &managerImpl{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") - - assert.Error(t, err) - sErr, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, status.AlreadyExists, sErr.Type()) - }) - - t.Run("nil existing service with nil error", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStore := store.NewMockStore(ctrl) - mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, "nil.com"). - Return(nil, nil) - - mgr := &managerImpl{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") - - assert.NoError(t, err) - }) -} - -func TestPersistNewService(t *testing.T) { - ctx := context.Background() - accountID := "test-account" - - t.Run("successful service creation with no targets", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStore := store.NewMockStore(ctrl) - service := &reverseproxy.Service{ - ID: "service-123", - Domain: "new.com", - Targets: []*reverseproxy.Target{}, - } - - // Mock ExecuteInTransaction to execute the function immediately - mockStore.EXPECT(). - ExecuteInTransaction(ctx, gomock.Any()). - DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { - // Create another mock for the transaction - txMock := store.NewMockStore(ctrl) - txMock.EXPECT(). - GetServiceByDomain(ctx, accountID, "new.com"). - Return(nil, status.Errorf(status.NotFound, "not found")) - txMock.EXPECT(). - CreateService(ctx, service). - Return(nil) - - return fn(txMock) - }) - - mgr := &managerImpl{store: mockStore} - err := mgr.persistNewService(ctx, accountID, service) - - assert.NoError(t, err) - }) - - t.Run("domain already exists", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStore := store.NewMockStore(ctrl) - service := &reverseproxy.Service{ - ID: "service-123", - Domain: "existing.com", - Targets: []*reverseproxy.Target{}, - } - - mockStore.EXPECT(). - ExecuteInTransaction(ctx, gomock.Any()). - DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { - txMock := store.NewMockStore(ctrl) - txMock.EXPECT(). - GetServiceByDomain(ctx, accountID, "existing.com"). - Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil) - - return fn(txMock) - }) - - mgr := &managerImpl{store: mockStore} - err := mgr.persistNewService(ctx, accountID, service) - - require.Error(t, err) - sErr, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, status.AlreadyExists, sErr.Type()) - }) -} -func TestPreserveExistingAuthSecrets(t *testing.T) { - mgr := &managerImpl{} - - t.Run("preserve password when empty", func(t *testing.T) { - existing := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ - Enabled: true, - Password: "hashed-password", - }, - }, - } - - updated := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ - Enabled: true, - Password: "", - }, - }, - } - - mgr.preserveExistingAuthSecrets(updated, existing) - - assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth) - }) - - t.Run("preserve pin when empty", func(t *testing.T) { - existing := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PinAuth: &reverseproxy.PINAuthConfig{ - Enabled: true, - Pin: "hashed-pin", - }, - }, - } - - updated := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PinAuth: &reverseproxy.PINAuthConfig{ - Enabled: true, - Pin: "", - }, - }, - } - - mgr.preserveExistingAuthSecrets(updated, existing) - - assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth) - }) - - t.Run("do not preserve when password is provided", func(t *testing.T) { - existing := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ - Enabled: true, - Password: "old-password", - }, - }, - } - - updated := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ - Enabled: true, - Password: "new-password", - }, - }, - } - - mgr.preserveExistingAuthSecrets(updated, existing) - - assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password) - assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth) - }) -} - -func TestPreserveServiceMetadata(t *testing.T) { - mgr := &managerImpl{} - - existing := &reverseproxy.Service{ - Meta: reverseproxy.ServiceMeta{ - CertificateIssuedAt: time.Now(), - Status: "active", - }, - SessionPrivateKey: "private-key", - SessionPublicKey: "public-key", - } - - updated := &reverseproxy.Service{ - Domain: "updated.com", - } - - mgr.preserveServiceMetadata(updated, existing) - - assert.Equal(t, existing.Meta, updated.Meta) - assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey) - assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey) -} diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go new file mode 100644 index 000000000..aa7cd8630 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -0,0 +1,40 @@ +package proxy + +//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + +import ( + "context" + "time" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// Manager defines the interface for proxy operations +type Manager interface { + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error + Disconnect(ctx context.Context, proxyID string) error + Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusters(ctx context.Context) ([]Cluster, error) + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + CleanupStale(ctx context.Context, inactivityDuration time.Duration) error +} + +// OIDCValidationConfig contains the OIDC configuration needed for token validation. +type OIDCValidationConfig struct { + Issuer string + Audiences []string + KeysLocation string + MaxTokenAgeSeconds int64 +} + +// Controller is responsible for managing proxy clusters and routing service updates. +type Controller interface { + SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) + GetOIDCValidationConfig() OIDCValidationConfig + RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error + UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error + GetProxiesForCluster(clusterAddr string) []string +} diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go new file mode 100644 index 000000000..e5b3e9886 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -0,0 +1,88 @@ +package manager + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC. +type GRPCController struct { + proxyGRPCServer *nbgrpc.ProxyServiceServer + // Map of cluster address -> set of proxy IDs + clusterProxies sync.Map + metrics *metrics +} + +// NewGRPCController creates a new GRPCController. +func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) { + m, err := newMetrics(meter) + if err != nil { + return nil, err + } + + return &GRPCController{ + proxyGRPCServer: proxyGRPCServer, + metrics: m, + }, nil +} + +// SendServiceUpdateToCluster sends a service update to a specific proxy cluster. +func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) { + c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr) + c.metrics.IncrementServiceUpdateSendCount(clusterAddr) +} + +// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server. +func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig { + return c.proxyGRPCServer.GetOIDCValidationConfig() +} + +// RegisterProxyToCluster registers a proxy to a specific cluster for routing. +func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error { + if clusterAddr == "" { + return nil + } + proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) + proxySet.(*sync.Map).Store(proxyID, struct{}{}) + log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr) + + c.metrics.IncrementProxyConnectionCount(clusterAddr) + + return nil +} + +// UnregisterProxyFromCluster removes a proxy from a cluster. +func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error { + if clusterAddr == "" { + return nil + } + if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok { + proxySet.(*sync.Map).Delete(proxyID) + log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr) + + c.metrics.DecrementProxyConnectionCount(clusterAddr) + } + return nil +} + +// GetProxiesForCluster returns all proxy IDs registered for a specific cluster. +func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { + proxySet, ok := c.clusterProxies.Load(clusterAddr) + if !ok { + return nil + } + + var proxies []string + proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { + proxies = append(proxies, key.(string)) + return true + }) + return proxies +} diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go new file mode 100644 index 000000000..d13334e83 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -0,0 +1,155 @@ +package manager + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +// store defines the interface for proxy persistence operations +type store interface { + SaveProxy(ctx context.Context, p *proxy.Proxy) error + UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error +} + +// Manager handles all proxy operations +type Manager struct { + store store + metrics *metrics +} + +// NewManager creates a new proxy Manager +func NewManager(store store, meter metric.Meter) (*Manager, error) { + m, err := newMetrics(meter) + if err != nil { + return nil, err + } + + return &Manager{ + store: store, + metrics: m, + }, nil +} + +// Connect registers a new proxy connection in the database. +// capabilities may be nil for old proxies that do not report them. +func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { + now := time.Now() + var caps proxy.Capabilities + if capabilities != nil { + caps = *capabilities + } + p := &proxy.Proxy{ + ID: proxyID, + ClusterAddress: clusterAddress, + IPAddress: ipAddress, + LastSeen: now, + ConnectedAt: &now, + Status: "connected", + Capabilities: caps, + } + + if err := m.store.SaveProxy(ctx, p); err != nil { + log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) + return err + } + + log.WithContext(ctx).WithFields(log.Fields{ + "proxyID": proxyID, + "clusterAddress": clusterAddress, + "ipAddress": ipAddress, + }).Info("proxy connected") + + return nil +} + +// Disconnect marks a proxy as disconnected in the database +func (m Manager) Disconnect(ctx context.Context, proxyID string) error { + now := time.Now() + p := &proxy.Proxy{ + ID: proxyID, + Status: "disconnected", + DisconnectedAt: &now, + LastSeen: now, + } + + if err := m.store.SaveProxy(ctx, p); err != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) + return err + } + + log.WithContext(ctx).WithFields(log.Fields{ + "proxyID": proxyID, + }).Info("proxy disconnected") + + return nil +} + +// Heartbeat updates the proxy's last seen timestamp +func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { + log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) + return err + } + + log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID) + m.metrics.IncrementProxyHeartbeatCount() + return nil +} + +// GetActiveClusterAddresses returns all unique cluster addresses for active proxies +func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) + return nil, err + } + return addresses, nil +} + +// GetActiveClusters returns all active proxy clusters with their connected proxy count. +func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { + clusters, err := m.store.GetActiveProxyClusters(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) + return nil, err + } + return clusters, nil +} + +// ClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr) +} + +// ClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterRequireSubdomain(ctx, clusterAddr) +} + +// ClusterSupportsCrowdSec returns whether all active proxies in the cluster +// have CrowdSec configured (unanimous). Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr) +} + +// CleanupStale removes proxies that haven't sent heartbeat in the specified duration +func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { + if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { + log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) + return err + } + return nil +} diff --git a/management/internals/modules/reverseproxy/proxy/manager/metrics.go b/management/internals/modules/reverseproxy/proxy/manager/metrics.go new file mode 100644 index 000000000..2b402cead --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/metrics.go @@ -0,0 +1,74 @@ +package manager + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +type metrics struct { + proxyConnectionCount metric.Int64UpDownCounter + serviceUpdateSendCount metric.Int64Counter + proxyHeartbeatCount metric.Int64Counter +} + +func newMetrics(meter metric.Meter) (*metrics, error) { + proxyConnectionCount, err := meter.Int64UpDownCounter( + "management_proxy_connection_count", + metric.WithDescription("Number of active proxy connections"), + metric.WithUnit("{connection}"), + ) + if err != nil { + return nil, err + } + + serviceUpdateSendCount, err := meter.Int64Counter( + "management_proxy_service_update_send_count", + metric.WithDescription("Total number of service updates sent to proxies"), + metric.WithUnit("{update}"), + ) + if err != nil { + return nil, err + } + + proxyHeartbeatCount, err := meter.Int64Counter( + "management_proxy_heartbeat_count", + metric.WithDescription("Total number of proxy heartbeats received"), + metric.WithUnit("{heartbeat}"), + ) + if err != nil { + return nil, err + } + + return &metrics{ + proxyConnectionCount: proxyConnectionCount, + serviceUpdateSendCount: serviceUpdateSendCount, + proxyHeartbeatCount: proxyHeartbeatCount, + }, nil +} + +func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) { + m.proxyConnectionCount.Add(context.Background(), 1, + metric.WithAttributes( + attribute.String("cluster", clusterAddr), + )) +} + +func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) { + m.proxyConnectionCount.Add(context.Background(), -1, + metric.WithAttributes( + attribute.String("cluster", clusterAddr), + )) +} + +func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) { + m.serviceUpdateSendCount.Add(context.Background(), 1, + metric.WithAttributes( + attribute.String("cluster", clusterAddr), + )) +} + +func (m *metrics) IncrementProxyHeartbeatCount() { + m.proxyHeartbeatCount.Add(context.Background(), 1) +} diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go new file mode 100644 index 000000000..282ca0ba5 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -0,0 +1,256 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./manager.go + +// Package proxy is a generated GoMock package. +package proxy + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + proto "github.com/netbirdio/netbird/shared/management/proto" +) + +// MockManager is a mock of Manager interface. +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder +} + +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// CleanupStale mocks base method. +func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupStale indicates an expected call of CleanupStale. +func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration) +} + +// ClusterSupportsCustomPorts mocks base method. +func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. +func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr) +} + +// ClusterRequireSubdomain mocks base method. +func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain. +func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr) +} + +// ClusterSupportsCrowdSec mocks base method. +func (m *MockManager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCrowdSec indicates an expected call of ClusterSupportsCrowdSec. +func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr) +} + +// Connect mocks base method. +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) + ret0, _ := ret[0].(error) + return ret0 +} + +// Connect indicates an expected call of Connect. +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) +} + +// Disconnect mocks base method. +func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Disconnect indicates an expected call of Disconnect. +func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID) +} + +// GetActiveClusterAddresses mocks base method. +func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses. +func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) +} + +// GetActiveClusters mocks base method. +func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusters", ctx) + ret0, _ := ret[0].([]Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusters indicates an expected call of GetActiveClusters. +func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) +} + +// Heartbeat mocks base method. +func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret0, _ := ret[0].(error) + return ret0 +} + +// Heartbeat indicates an expected call of Heartbeat. +func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) +} + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// GetOIDCValidationConfig mocks base method. +func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOIDCValidationConfig") + ret0, _ := ret[0].(OIDCValidationConfig) + return ret0 +} + +// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig. +func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) +} + +// GetProxiesForCluster mocks base method. +func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr) + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetProxiesForCluster indicates an expected call of GetProxiesForCluster. +func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr) +} + +// RegisterProxyToCluster mocks base method. +func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster. +func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID) +} + +// SendServiceUpdateToCluster mocks base method. +func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr) +} + +// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster. +func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr) +} + +// UnregisterProxyFromCluster mocks base method. +func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster. +func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID) +} diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go new file mode 100644 index 000000000..339c82446 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -0,0 +1,40 @@ +package proxy + +import "time" + +// Capabilities describes what a proxy can handle, as reported via gRPC. +// Nil fields mean the proxy never reported this capability. +type Capabilities struct { + // SupportsCustomPorts indicates whether this proxy can bind arbitrary + // ports for TCP/UDP services. TLS uses SNI routing and is not gated. + SupportsCustomPorts *bool + // RequireSubdomain indicates whether a subdomain label is required in + // front of the cluster domain. + RequireSubdomain *bool + // SupportsCrowdsec indicates whether this proxy has CrowdSec configured. + SupportsCrowdsec *bool +} + +// Proxy represents a reverse proxy instance +type Proxy struct { + ID string `gorm:"primaryKey;type:varchar(255)"` + ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` + IPAddress string `gorm:"type:varchar(45)"` + LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` + ConnectedAt *time.Time + DisconnectedAt *time.Time + Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Capabilities Capabilities `gorm:"embedded"` + CreatedAt time.Time + UpdatedAt time.Time +} + +func (Proxy) TableName() string { + return "proxies" +} + +// Cluster represents a group of proxy nodes serving the same address. +type Cluster struct { + Address string + ConnectedProxies int +} diff --git a/management/internals/modules/reverseproxy/reverseproxy.go b/management/internals/modules/reverseproxy/reverseproxy.go deleted file mode 100644 index 0cbbe450b..000000000 --- a/management/internals/modules/reverseproxy/reverseproxy.go +++ /dev/null @@ -1,463 +0,0 @@ -package reverseproxy - -import ( - "errors" - "fmt" - "net" - "net/url" - "strconv" - "time" - - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/shared/hash/argon2id" - "github.com/netbirdio/netbird/util/crypt" - - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/proto" -) - -type Operation string - -const ( - Create Operation = "create" - Update Operation = "update" - Delete Operation = "delete" -) - -type ProxyStatus string - -const ( - StatusPending ProxyStatus = "pending" - StatusActive ProxyStatus = "active" - StatusTunnelNotCreated ProxyStatus = "tunnel_not_created" - StatusCertificatePending ProxyStatus = "certificate_pending" - StatusCertificateFailed ProxyStatus = "certificate_failed" - StatusError ProxyStatus = "error" - - TargetTypePeer = "peer" - TargetTypeHost = "host" - TargetTypeDomain = "domain" - TargetTypeSubnet = "subnet" -) - -type Target struct { - ID uint `gorm:"primaryKey" json:"-"` - AccountID string `gorm:"index:idx_target_account;not null" json:"-"` - ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` - Path *string `json:"path,omitempty"` - Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored - Port int `gorm:"index:idx_target_port" json:"port"` - Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` - TargetId string `gorm:"index:idx_target_id" json:"target_id"` - TargetType string `gorm:"index:idx_target_type" json:"target_type"` - Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` -} - -type PasswordAuthConfig struct { - Enabled bool `json:"enabled"` - Password string `json:"password"` -} - -type PINAuthConfig struct { - Enabled bool `json:"enabled"` - Pin string `json:"pin"` -} - -type BearerAuthConfig struct { - Enabled bool `json:"enabled"` - DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"` -} - -type AuthConfig struct { - PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"` - PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"` - BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"` -} - -func (a *AuthConfig) HashSecrets() error { - if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" { - hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password) - if err != nil { - return fmt.Errorf("hash password: %w", err) - } - a.PasswordAuth.Password = hashedPassword - } - - if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" { - hashedPin, err := argon2id.Hash(a.PinAuth.Pin) - if err != nil { - return fmt.Errorf("hash pin: %w", err) - } - a.PinAuth.Pin = hashedPin - } - - return nil -} - -func (a *AuthConfig) ClearSecrets() { - if a.PasswordAuth != nil { - a.PasswordAuth.Password = "" - } - if a.PinAuth != nil { - a.PinAuth.Pin = "" - } -} - -type OIDCValidationConfig struct { - Issuer string - Audiences []string - KeysLocation string - MaxTokenAgeSeconds int64 -} - -type ServiceMeta struct { - CreatedAt time.Time - CertificateIssuedAt time.Time - Status string -} - -type Service struct { - ID string `gorm:"primaryKey"` - AccountID string `gorm:"index"` - Name string - Domain string `gorm:"index"` - ProxyCluster string `gorm:"index"` - Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"` - Enabled bool - PassHostHeader bool - RewriteRedirects bool - Auth AuthConfig `gorm:"serializer:json"` - Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"` - SessionPrivateKey string `gorm:"column:session_private_key"` - SessionPublicKey string `gorm:"column:session_public_key"` -} - -func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service { - for _, target := range targets { - target.AccountID = accountID - } - - s := &Service{ - AccountID: accountID, - Name: name, - Domain: domain, - ProxyCluster: proxyCluster, - Targets: targets, - Enabled: enabled, - } - s.InitNewRecord() - return s -} - -// InitNewRecord generates a new unique ID and resets metadata for a newly created -// Service record. This overwrites any existing ID and Meta fields and should -// only be called during initial creation, not for updates. -func (s *Service) InitNewRecord() { - s.ID = xid.New().String() - s.Meta = ServiceMeta{ - CreatedAt: time.Now(), - Status: string(StatusPending), - } -} - -func (s *Service) ToAPIResponse() *api.Service { - s.Auth.ClearSecrets() - - authConfig := api.ServiceAuthConfig{} - - if s.Auth.PasswordAuth != nil { - authConfig.PasswordAuth = &api.PasswordAuthConfig{ - Enabled: s.Auth.PasswordAuth.Enabled, - Password: s.Auth.PasswordAuth.Password, - } - } - - if s.Auth.PinAuth != nil { - authConfig.PinAuth = &api.PINAuthConfig{ - Enabled: s.Auth.PinAuth.Enabled, - Pin: s.Auth.PinAuth.Pin, - } - } - - if s.Auth.BearerAuth != nil { - authConfig.BearerAuth = &api.BearerAuthConfig{ - Enabled: s.Auth.BearerAuth.Enabled, - DistributionGroups: &s.Auth.BearerAuth.DistributionGroups, - } - } - - // Convert internal targets to API targets - apiTargets := make([]api.ServiceTarget, 0, len(s.Targets)) - for _, target := range s.Targets { - apiTargets = append(apiTargets, api.ServiceTarget{ - Path: target.Path, - Host: &target.Host, - Port: target.Port, - Protocol: api.ServiceTargetProtocol(target.Protocol), - TargetId: target.TargetId, - TargetType: api.ServiceTargetTargetType(target.TargetType), - Enabled: target.Enabled, - }) - } - - meta := api.ServiceMeta{ - CreatedAt: s.Meta.CreatedAt, - Status: api.ServiceMetaStatus(s.Meta.Status), - } - - if !s.Meta.CertificateIssuedAt.IsZero() { - meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt - } - - resp := &api.Service{ - Id: s.ID, - Name: s.Name, - Domain: s.Domain, - Targets: apiTargets, - Enabled: s.Enabled, - PassHostHeader: &s.PassHostHeader, - RewriteRedirects: &s.RewriteRedirects, - Auth: authConfig, - Meta: meta, - } - - if s.ProxyCluster != "" { - resp.ProxyCluster = &s.ProxyCluster - } - - return resp -} - -func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping { - pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) - for _, target := range s.Targets { - if !target.Enabled { - continue - } - - // TODO: Make path prefix stripping configurable per-target. - // Currently the matching prefix is baked into the target URL path, - // so the proxy strips-then-re-adds it (effectively a no-op). - targetURL := url.URL{ - Scheme: target.Protocol, - Host: target.Host, - Path: "/", // TODO: support service path - } - if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port)) - } - - path := "/" - if target.Path != nil { - path = *target.Path - } - pathMappings = append(pathMappings, &proto.PathMapping{ - Path: path, - Target: targetURL.String(), - }) - } - - auth := &proto.Authentication{ - SessionKey: s.SessionPublicKey, - MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()), - } - - if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled { - auth.Password = true - } - - if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled { - auth.Pin = true - } - - if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled { - auth.Oidc = true - } - - return &proto.ProxyMapping{ - Type: operationToProtoType(operation), - Id: s.ID, - Domain: s.Domain, - Path: pathMappings, - AuthToken: authToken, - Auth: auth, - AccountId: s.AccountID, - PassHostHeader: s.PassHostHeader, - RewriteRedirects: s.RewriteRedirects, - } -} - -func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { - switch op { - case Create: - return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED - case Update: - return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED - case Delete: - return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED - default: - log.Fatalf("unknown operation type: %v", op) - return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED - } -} - -// isDefaultPort reports whether port is the standard default for the given scheme -// (443 for https, 80 for http). -func isDefaultPort(scheme string, port int) bool { - return (scheme == "https" && port == 443) || (scheme == "http" && port == 80) -} - -func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) { - s.Name = req.Name - s.Domain = req.Domain - s.AccountID = accountID - - targets := make([]*Target, 0, len(req.Targets)) - for _, apiTarget := range req.Targets { - target := &Target{ - AccountID: accountID, - Path: apiTarget.Path, - Port: apiTarget.Port, - Protocol: string(apiTarget.Protocol), - TargetId: apiTarget.TargetId, - TargetType: string(apiTarget.TargetType), - Enabled: apiTarget.Enabled, - } - if apiTarget.Host != nil { - target.Host = *apiTarget.Host - } - targets = append(targets, target) - } - s.Targets = targets - - s.Enabled = req.Enabled - - if req.PassHostHeader != nil { - s.PassHostHeader = *req.PassHostHeader - } - - if req.RewriteRedirects != nil { - s.RewriteRedirects = *req.RewriteRedirects - } - - if req.Auth.PasswordAuth != nil { - s.Auth.PasswordAuth = &PasswordAuthConfig{ - Enabled: req.Auth.PasswordAuth.Enabled, - Password: req.Auth.PasswordAuth.Password, - } - } - - if req.Auth.PinAuth != nil { - s.Auth.PinAuth = &PINAuthConfig{ - Enabled: req.Auth.PinAuth.Enabled, - Pin: req.Auth.PinAuth.Pin, - } - } - - if req.Auth.BearerAuth != nil { - bearerAuth := &BearerAuthConfig{ - Enabled: req.Auth.BearerAuth.Enabled, - } - if req.Auth.BearerAuth.DistributionGroups != nil { - bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups - } - s.Auth.BearerAuth = bearerAuth - } -} - -func (s *Service) Validate() error { - if s.Name == "" { - return errors.New("service name is required") - } - if len(s.Name) > 255 { - return errors.New("service name exceeds maximum length of 255 characters") - } - - if s.Domain == "" { - return errors.New("service domain is required") - } - - if len(s.Targets) == 0 { - return errors.New("at least one target is required") - } - - for i, target := range s.Targets { - switch target.TargetType { - case TargetTypePeer, TargetTypeHost, TargetTypeDomain: - // host field will be ignored - case TargetTypeSubnet: - if target.Host == "" { - return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType) - } - default: - return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType) - } - if target.TargetId == "" { - return fmt.Errorf("target %d has empty target_id", i) - } - } - - return nil -} - -func (s *Service) EventMeta() map[string]any { - return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster} -} - -func (s *Service) Copy() *Service { - targets := make([]*Target, len(s.Targets)) - for i, target := range s.Targets { - targetCopy := *target - targets[i] = &targetCopy - } - - return &Service{ - ID: s.ID, - AccountID: s.AccountID, - Name: s.Name, - Domain: s.Domain, - ProxyCluster: s.ProxyCluster, - Targets: targets, - Enabled: s.Enabled, - PassHostHeader: s.PassHostHeader, - RewriteRedirects: s.RewriteRedirects, - Auth: s.Auth, - Meta: s.Meta, - SessionPrivateKey: s.SessionPrivateKey, - SessionPublicKey: s.SessionPublicKey, - } -} - -func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error { - if enc == nil { - return nil - } - - if s.SessionPrivateKey != "" { - var err error - s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey) - if err != nil { - return err - } - } - - return nil -} - -func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { - if enc == nil { - return nil - } - - if s.SessionPrivateKey != "" { - var err error - s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey) - if err != nil { - return err - } - } - - return nil -} diff --git a/management/internals/modules/reverseproxy/reverseproxy_test.go b/management/internals/modules/reverseproxy/reverseproxy_test.go deleted file mode 100644 index 546e80b31..000000000 --- a/management/internals/modules/reverseproxy/reverseproxy_test.go +++ /dev/null @@ -1,405 +0,0 @@ -package reverseproxy - -import ( - "errors" - "fmt" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/shared/hash/argon2id" - "github.com/netbirdio/netbird/shared/management/proto" -) - -func validProxy() *Service { - return &Service{ - Name: "test", - Domain: "example.com", - Targets: []*Target{ - {TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true}, - }, - } -} - -func TestValidate_Valid(t *testing.T) { - require.NoError(t, validProxy().Validate()) -} - -func TestValidate_EmptyName(t *testing.T) { - rp := validProxy() - rp.Name = "" - assert.ErrorContains(t, rp.Validate(), "name is required") -} - -func TestValidate_EmptyDomain(t *testing.T) { - rp := validProxy() - rp.Domain = "" - assert.ErrorContains(t, rp.Validate(), "domain is required") -} - -func TestValidate_NoTargets(t *testing.T) { - rp := validProxy() - rp.Targets = nil - assert.ErrorContains(t, rp.Validate(), "at least one target") -} - -func TestValidate_EmptyTargetId(t *testing.T) { - rp := validProxy() - rp.Targets[0].TargetId = "" - assert.ErrorContains(t, rp.Validate(), "empty target_id") -} - -func TestValidate_InvalidTargetType(t *testing.T) { - rp := validProxy() - rp.Targets[0].TargetType = "invalid" - assert.ErrorContains(t, rp.Validate(), "invalid target_type") -} - -func TestValidate_ResourceTarget(t *testing.T) { - rp := validProxy() - rp.Targets = append(rp.Targets, &Target{ - TargetId: "resource-1", - TargetType: TargetTypeHost, - Host: "example.org", - Port: 443, - Protocol: "https", - Enabled: true, - }) - require.NoError(t, rp.Validate()) -} - -func TestValidate_MultipleTargetsOneInvalid(t *testing.T) { - rp := validProxy() - rp.Targets = append(rp.Targets, &Target{ - TargetId: "", - TargetType: TargetTypePeer, - Host: "10.0.0.2", - Port: 80, - Protocol: "http", - Enabled: true, - }) - err := rp.Validate() - require.Error(t, err) - assert.Contains(t, err.Error(), "target 1") - assert.Contains(t, err.Error(), "empty target_id") -} - -func TestIsDefaultPort(t *testing.T) { - tests := []struct { - scheme string - port int - want bool - }{ - {"http", 80, true}, - {"https", 443, true}, - {"http", 443, false}, - {"https", 80, false}, - {"http", 8080, false}, - {"https", 8443, false}, - {"http", 0, false}, - {"https", 0, false}, - } - for _, tt := range tests { - t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) { - assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port)) - }) - } -} - -func TestToProtoMapping_PortInTargetURL(t *testing.T) { - oidcConfig := OIDCValidationConfig{} - - tests := []struct { - name string - protocol string - host string - port int - wantTarget string - }{ - { - name: "http with default port 80 omits port", - protocol: "http", - host: "10.0.0.1", - port: 80, - wantTarget: "http://10.0.0.1/", - }, - { - name: "https with default port 443 omits port", - protocol: "https", - host: "10.0.0.1", - port: 443, - wantTarget: "https://10.0.0.1/", - }, - { - name: "port 0 omits port", - protocol: "http", - host: "10.0.0.1", - port: 0, - wantTarget: "http://10.0.0.1/", - }, - { - name: "non-default port is included", - protocol: "http", - host: "10.0.0.1", - port: 8080, - wantTarget: "http://10.0.0.1:8080/", - }, - { - name: "https with non-default port is included", - protocol: "https", - host: "10.0.0.1", - port: 8443, - wantTarget: "https://10.0.0.1:8443/", - }, - { - name: "http port 443 is included", - protocol: "http", - host: "10.0.0.1", - port: 443, - wantTarget: "http://10.0.0.1:443/", - }, - { - name: "https port 80 is included", - protocol: "https", - host: "10.0.0.1", - port: 80, - wantTarget: "https://10.0.0.1:80/", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rp := &Service{ - ID: "test-id", - AccountID: "acc-1", - Domain: "example.com", - Targets: []*Target{ - { - TargetId: "peer-1", - TargetType: TargetTypePeer, - Host: tt.host, - Port: tt.port, - Protocol: tt.protocol, - Enabled: true, - }, - }, - } - pm := rp.ToProtoMapping(Create, "token", oidcConfig) - require.Len(t, pm.Path, 1, "should have one path mapping") - assert.Equal(t, tt.wantTarget, pm.Path[0].Target) - }) - } -} - -func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) { - rp := &Service{ - ID: "test-id", - AccountID: "acc-1", - Domain: "example.com", - Targets: []*Target{ - {TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false}, - {TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true}, - }, - } - pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{}) - require.Len(t, pm.Path, 1) - assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target) -} - -func TestToProtoMapping_OperationTypes(t *testing.T) { - rp := validProxy() - tests := []struct { - op Operation - want proto.ProxyMappingUpdateType - }{ - {Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED}, - {Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED}, - {Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED}, - } - for _, tt := range tests { - t.Run(string(tt.op), func(t *testing.T) { - pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{}) - assert.Equal(t, tt.want, pm.Type) - }) - } -} - -func TestAuthConfig_HashSecrets(t *testing.T) { - tests := []struct { - name string - config *AuthConfig - wantErr bool - validate func(*testing.T, *AuthConfig) - }{ - { - name: "hash password successfully", - config: &AuthConfig{ - PasswordAuth: &PasswordAuthConfig{ - Enabled: true, - Password: "testPassword123", - }, - }, - wantErr: false, - validate: func(t *testing.T, config *AuthConfig) { - if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") { - t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password) - } - // Verify the hash can be verified - if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil { - t.Errorf("Hash verification failed: %v", err) - } - }, - }, - { - name: "hash PIN successfully", - config: &AuthConfig{ - PinAuth: &PINAuthConfig{ - Enabled: true, - Pin: "123456", - }, - }, - wantErr: false, - validate: func(t *testing.T, config *AuthConfig) { - if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") { - t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin) - } - // Verify the hash can be verified - if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil { - t.Errorf("Hash verification failed: %v", err) - } - }, - }, - { - name: "hash both password and PIN", - config: &AuthConfig{ - PasswordAuth: &PasswordAuthConfig{ - Enabled: true, - Password: "password", - }, - PinAuth: &PINAuthConfig{ - Enabled: true, - Pin: "9999", - }, - }, - wantErr: false, - validate: func(t *testing.T, config *AuthConfig) { - if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") { - t.Errorf("Password not hashed with argon2id") - } - if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") { - t.Errorf("PIN not hashed with argon2id") - } - if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil { - t.Errorf("Password hash verification failed: %v", err) - } - if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil { - t.Errorf("PIN hash verification failed: %v", err) - } - }, - }, - { - name: "skip disabled password auth", - config: &AuthConfig{ - PasswordAuth: &PasswordAuthConfig{ - Enabled: false, - Password: "password", - }, - }, - wantErr: false, - validate: func(t *testing.T, config *AuthConfig) { - if config.PasswordAuth.Password != "password" { - t.Errorf("Disabled password auth should not be hashed") - } - }, - }, - { - name: "skip empty password", - config: &AuthConfig{ - PasswordAuth: &PasswordAuthConfig{ - Enabled: true, - Password: "", - }, - }, - wantErr: false, - validate: func(t *testing.T, config *AuthConfig) { - if config.PasswordAuth.Password != "" { - t.Errorf("Empty password should remain empty") - } - }, - }, - { - name: "skip nil password auth", - config: &AuthConfig{ - PasswordAuth: nil, - PinAuth: &PINAuthConfig{ - Enabled: true, - Pin: "1234", - }, - }, - wantErr: false, - validate: func(t *testing.T, config *AuthConfig) { - if config.PasswordAuth != nil { - t.Errorf("PasswordAuth should remain nil") - } - if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") { - t.Errorf("PIN should still be hashed") - } - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.config.HashSecrets() - if (err != nil) != tt.wantErr { - t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.validate != nil { - tt.validate(t, tt.config) - } - }) - } -} - -func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) { - config := &AuthConfig{ - PasswordAuth: &PasswordAuthConfig{ - Enabled: true, - Password: "correctPassword", - }, - } - - if err := config.HashSecrets(); err != nil { - t.Fatalf("HashSecrets() error = %v", err) - } - - // Verify with wrong password should fail - err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password) - if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) { - t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err) - } -} - -func TestAuthConfig_ClearSecrets(t *testing.T) { - config := &AuthConfig{ - PasswordAuth: &PasswordAuthConfig{ - Enabled: true, - Password: "hashedPassword", - }, - PinAuth: &PINAuthConfig{ - Enabled: true, - Pin: "hashedPin", - }, - } - - config.ClearSecrets() - - if config.PasswordAuth.Password != "" { - t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password) - } - if config.PinAuth.Pin != "" { - t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin) - } -} diff --git a/management/internals/modules/reverseproxy/interface.go b/management/internals/modules/reverseproxy/service/interface.go similarity index 59% rename from management/internals/modules/reverseproxy/interface.go rename to management/internals/modules/reverseproxy/service/interface.go index 7614b3ce5..a49cbea35 100644 --- a/management/internals/modules/reverseproxy/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -1,23 +1,31 @@ -package reverseproxy +package service -//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod import ( "context" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" ) type Manager interface { + GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) DeleteService(ctx context.Context, accountID, userID, serviceID string) error + DeleteAllServices(ctx context.Context, accountID, userID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error - SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error + SetStatus(ctx context.Context, accountID, serviceID string, status Status) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error ReloadService(ctx context.Context, accountID, serviceID string) error GetGlobalServices(ctx context.Context) ([]*Service, error) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) + CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) + RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error + StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error + StartExposeReaper(ctx context.Context) } diff --git a/management/internals/modules/reverseproxy/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go similarity index 69% rename from management/internals/modules/reverseproxy/interface_mock.go rename to management/internals/modules/reverseproxy/service/interface_mock.go index d5f38c38a..cc5ccbb8e 100644 --- a/management/internals/modules/reverseproxy/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -1,14 +1,15 @@ // Code generated by MockGen. DO NOT EDIT. // Source: ./interface.go -// Package reverseproxy is a generated GoMock package. -package reverseproxy +// Package service is a generated GoMock package. +package service import ( context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" + proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" ) // MockManager is a mock of Manager interface. @@ -49,6 +50,35 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service) } +// CreateServiceFromPeer mocks base method. +func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req) + ret0, _ := ret[0].(*ExposeServiceResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer. +func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req) +} + +// DeleteAllServices mocks base method. +func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllServices indicates an expected call of DeleteAllServices. +func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() @@ -78,6 +108,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID) } +// GetActiveClusters mocks base method. +func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusters indicates an expected call of GetActiveClusters. +func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID) +} + // GetAllServices mocks base method. func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) { m.ctrl.T.Helper() @@ -181,6 +226,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID) } +// RenewServiceFromPeer mocks base method. +func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer. +func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID) +} + // SetCertificateIssuedAt mocks base method. func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { m.ctrl.T.Helper() @@ -196,7 +255,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic } // SetStatus mocks base method. -func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error { +func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status) ret0, _ := ret[0].(error) @@ -209,6 +268,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status) } +// StartExposeReaper mocks base method. +func (m *MockManager) StartExposeReaper(ctx context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartExposeReaper", ctx) +} + +// StartExposeReaper indicates an expected call of StartExposeReaper. +func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx) +} + +// StopServiceFromPeer mocks base method. +func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// StopServiceFromPeer indicates an expected call of StopServiceFromPeer. +func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID) +} + // UpdateService mocks base method. func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go similarity index 75% rename from management/internals/modules/reverseproxy/manager/api.go rename to management/internals/modules/reverseproxy/service/manager/api.go index 9117ecd38..cd81efa88 100644 --- a/management/internals/modules/reverseproxy/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -6,24 +6,27 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) type handler struct { - manager reverseproxy.Manager + manager rpservice.Manager + permissionsManager permissions.Manager } // RegisterEndpoints registers all service HTTP endpoints. -func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) { h := &handler{ - manager: manager, + manager: manager, + permissionsManager: permissionsManager, } domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() @@ -31,6 +34,7 @@ func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager accesslogsmanager.RegisterEndpoints(router, accessLogsManager) + router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") @@ -72,8 +76,11 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) { return } - service := new(reverseproxy.Service) - service.FromAPIRequest(&req, userAuth.AccountId) + service := new(rpservice.Service) + if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) + return + } if err = service.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) @@ -130,9 +137,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { return } - service := new(reverseproxy.Service) + service := new(rpservice.Service) service.ID = serviceID - service.FromAPIRequest(&req, userAuth.AccountId) + if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) + return + } if err = service.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) @@ -168,3 +178,27 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } + +func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + apiClusters := make([]api.ProxyCluster, 0, len(clusters)) + for _, c := range clusters { + apiClusters = append(apiClusters, api.ProxyCluster{ + Address: c.Address, + ConnectedProxies: c.ConnectedProxies, + }) + } + + util.WriteJSONObject(r.Context(), w, apiClusters) +} diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker.go new file mode 100644 index 000000000..911add3bb --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker.go @@ -0,0 +1,65 @@ +package manager + +import ( + "context" + "math/rand/v2" + "time" + + "github.com/netbirdio/netbird/shared/management/status" + log "github.com/sirupsen/logrus" +) + +const ( + exposeTTL = 90 * time.Second + exposeReapInterval = 30 * time.Second + maxExposesPerPeer = 10 + exposeReapBatch = 100 +) + +type exposeReaper struct { + manager *Manager +} + +// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB. +func (r *exposeReaper) StartExposeReaper(ctx context.Context) { + go func() { + // start with a random delay + rn := rand.IntN(10) + time.Sleep(time.Duration(rn) * time.Second) + + ticker := time.NewTicker(exposeReapInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + r.reapExpiredExposes(ctx) + } + } + }() +} + +func (r *exposeReaper) reapExpiredExposes(ctx context.Context) { + expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch) + if err != nil { + log.Errorf("failed to get expired ephemeral services: %v", err) + return + } + + for _, svc := range expired { + log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain) + + err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID) + if err == nil { + continue + } + + if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound { + log.Debugf("service %s was already deleted by another instance", svc.Domain) + } else { + log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err) + } + } +} diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go new file mode 100644 index 000000000..6ff8343b9 --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -0,0 +1,221 @@ +package manager + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" +) + +func TestReapExpiredExposes(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + }) + require.NoError(t, err) + + // Manually expire the service by backdating meta_last_renewed_at + expireEphemeralService(t, testStore, testAccountID, resp.Domain) + + // Create a non-expired service + resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8081, + Mode: "http", + }) + require.NoError(t, err) + + mgr.exposeReaper.reapExpiredExposes(ctx) + + // Expired service should be deleted + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) + require.Error(t, err, "expired service should be deleted") + + // Non-expired service should remain + _, err = testStore.GetServiceByDomain(ctx, resp2.Domain) + require.NoError(t, err, "active service should remain") +} + +func TestReapAlreadyDeletedService(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + }) + require.NoError(t, err) + + expireEphemeralService(t, testStore, testAccountID, resp.Domain) + + // Delete the service before reaping + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + // Reaping should handle the already-deleted service gracefully + mgr.exposeReaper.reapExpiredExposes(ctx) +} + +func TestConcurrentReapAndRenew(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + ctx := context.Background() + + for i := range 5 { + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: uint16(8080 + i), + Mode: "http", + }) + require.NoError(t, err) + } + + // Expire all services + services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + for _, svc := range services { + if svc.Source == rpservice.SourceEphemeral { + expireEphemeralService(t, testStore, testAccountID, svc.Domain) + } + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + mgr.exposeReaper.reapExpiredExposes(ctx) + }() + go func() { + defer wg.Done() + _, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + }() + wg.Wait() + + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count, "all expired services should be reaped") +} + +func TestRenewEphemeralService(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ctx := context.Background() + + t.Run("renew succeeds for active service", func(t *testing.T) { + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8082, + Mode: "http", + }) + require.NoError(t, err) + + svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, lookupErr) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) + require.NoError(t, err) + }) + + t.Run("renew fails for nonexistent domain", func(t *testing.T) { + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") + require.Error(t, err) + assert.Contains(t, err.Error(), "no active expose session") + }) +} + +func TestCountAndExistsEphemeralServices(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ctx := context.Background() + + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8083, + Mode: "http", + }) + require.NoError(t, err) + + count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + + exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) + assert.True(t, exists, "service should exist") + + exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain") + require.NoError(t, err) + assert.False(t, exists, "non-existent service should not exist") +} + +func TestMaxExposesPerPeerEnforced(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ctx := context.Background() + + for i := range maxExposesPerPeer { + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: uint16(8090 + i), + Mode: "http", + }) + require.NoError(t, err, "expose %d should succeed", i) + } + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 9999, + Mode: "http", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "maximum number of active expose sessions") +} + +func TestReapSkipsRenewedService(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8086, + Mode: "http", + }) + require.NoError(t, err) + + // Expire the service + expireEphemeralService(t, testStore, testAccountID, resp.Domain) + + // Renew it before the reaper runs + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) + require.NoError(t, err) + + // Reaper should skip it because the re-check sees a fresh timestamp + mgr.exposeReaper.reapExpiredExposes(ctx) + + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err, "renewed service should survive reaping") +} + +// resolveServiceIDByDomain looks up a service ID by domain in tests. +func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string { + t.Helper() + svc, err := s.GetServiceByDomain(context.Background(), domain) + require.NoError(t, err) + return svc.ID +} + +// expireEphemeralService backdates meta_last_renewed_at to force expiration. +func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { + t.Helper() + svc, err := s.GetServiceByDomain(context.Background(), domain) + require.NoError(t, err) + + expired := time.Now().Add(-2 * exposeTTL) + svc.Meta.LastRenewedAt = &expired + err = s.UpdateService(context.Background(), svc) + require.NoError(t, err) +} diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go new file mode 100644 index 000000000..28461641d --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -0,0 +1,587 @@ +package manager + +import ( + "context" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +const testCluster = "test-cluster" + +func boolPtr(v bool) *bool { return &v } + +// setupL4Test creates a manager with a mock proxy controller for L4 port tests. +func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) { + t.Helper() + + ctrl := gomock.NewController(t) + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + CreatedBy: testUserID, + Settings: &types.Settings{ + PeerExposeEnabled: true, + PeerExposeGroups: []string{testGroupID}, + }, + Users: map[string]*types.User{ + testUserID: { + Id: testUserID, + AccountID: testAccountID, + Role: types.UserRoleAdmin, + }, + }, + Peers: map[string]*nbpeer.Peer{ + testPeerID: { + ID: testPeerID, + AccountID: testAccountID, + Key: "test-key", + DNSLabel: "test-peer", + Name: "test-peer", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, + }, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + AccountID: testAccountID, + Name: "Expose Group", + }, + }, + }) + require.NoError(t, err) + + err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) + require.NoError(t, err) + + mockCtrl := proxy.NewMockController(ctrl) + mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes() + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes() + mockCaps.EXPECT().ClusterSupportsCrowdSec(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes() + + accountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) + }, + } + + mgr := &Manager{ + store: testStore, + accountManager: accountMgr, + permissionsManager: permissions.NewManager(testStore), + proxyController: mockCtrl, + capabilities: mockCaps, + clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, + } + mgr.exposeReaper = &exposeReaper{manager: mgr} + + return mgr, testStore, mockCtrl +} + +// seedService creates a service directly in the store for test setup. +func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service { + t.Helper() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: name, + Mode: protocol, + Domain: domain, + ProxyCluster: cluster, + ListenPort: port, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + err := s.CreateService(context.Background(), svc) + require.NoError(t, err) + return svc +} + +func TestPortConflict_TCPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-tcp", + Mode: "tcp", + Domain: "conflicting-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP+TCP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_UDPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-udp", + Mode: "udp", + Domain: "conflicting-udp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "UDP+UDP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tls", + Mode: "tls", + Domain: "app2.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)") +} + +func TestPortConflict_TLSSamePortSameDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "duplicate-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TLS+TLS on same domain should be rejected") + assert.Contains(t, err.Error(), "domain already taken") +} + +func TestPortConflict_TLSAndTCPSamePort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)") +} + +func TestAutoAssign_TCPNoListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax) + assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set") +} + +func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it") + assert.Contains(t, err.Error(), "custom ports") +} + +func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability") + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS") +} + +func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden") + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range", svc.ListenPort) + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_AvoidsExistingPorts(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existingPort := uint16(20000) + seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: "auto-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing") + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported") + assert.False(t, svc.PortAutoAssigned) +} + +func TestUpdate_PreservesExistingListenPort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc-renamed", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0") +} + +func TestUpdate_AllowsPortChange(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 54321, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied") +} + +func TestCreateServiceFromPeer_TCP(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + assert.NotEmpty(t, resp.ServiceName) + assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain") + assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports") + assert.Contains(t, resp.ServiceURL, "tcp://") +} + +func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + ListenPort: 15432, + }) + require.NoError(t, err) + + assert.False(t, resp.PortAutoAssigned) + assert.Contains(t, resp.ServiceURL, ":15432") +} + +func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + // When no explicit listen port, defaults to target port + assert.Contains(t, resp.ServiceURL, ":5432") + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TLS(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 443, + Mode: "tls", + }) + require.NoError(t, err) + + assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain") + assert.Contains(t, resp.ServiceURL, "tls://") + assert.Contains(t, resp.ServiceURL, ":443") + // TLS always keeps its port (not port-based protocol for auto-assign) + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + }) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + // Renew after stop should fail + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.Error(t, err) +} + +func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + Pin: "123456", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go new file mode 100644 index 000000000..ed9d4201b --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -0,0 +1,1278 @@ +package manager + +import ( + "context" + "fmt" + "math/rand/v2" + "net/http" + "os" + "slices" + "strconv" + "time" + + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + defaultAutoAssignPortMin uint16 = 10000 + defaultAutoAssignPortMax uint16 = 49151 + + // EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN" + // EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX" +) + +var ( + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax +) + +func init() { + autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin) + autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax) + if autoAssignPortMin > autoAssignPortMax { + log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults", + EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax) + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax + } +} + +func portFromEnv(key string, fallback uint16) uint16 { + val := os.Getenv(key) + if val == "" { + return fallback + } + n, err := strconv.ParseUint(val, 10, 16) + if err != nil { + log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err) + return fallback + } + return uint16(n) +} + +const unknownHostPlaceholder = "unknown" + +// ClusterDeriver derives the proxy cluster from a domain. +type ClusterDeriver interface { + DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) + GetClusterDomains() []string +} + +// CapabilityProvider queries proxy cluster capabilities from the database. +type CapabilityProvider interface { + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool +} + +type Manager struct { + store store.Store + accountManager account.Manager + permissionsManager permissions.Manager + proxyController proxy.Controller + capabilities CapabilityProvider + clusterDeriver ClusterDeriver + exposeReaper *exposeReaper +} + +// NewManager creates a new service manager. +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { + mgr := &Manager{ + store: store, + accountManager: accountManager, + permissionsManager: permissionsManager, + proxyController: proxyController, + capabilities: capabilities, + clusterDeriver: clusterDeriver, + } + mgr.exposeReaper = &exposeReaper{manager: mgr} + return mgr +} + +// StartExposeReaper starts the background goroutine that reaps expired ephemeral services. +func (m *Manager) StartExposeReaper(ctx context.Context) { + m.exposeReaper.StartExposeReaper(ctx) +} + +// GetActiveClusters returns all active proxy clusters with their connected proxy count. +func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetActiveProxyClusters(ctx) +} + +func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get services: %w", err) + } + + for _, service := range services { + err = m.replaceHostByLookup(ctx, accountID, service) + if err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + } + } + + return services, nil +} + +func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error { + for _, target := range s.Targets { + switch target.TargetType { + case service.TargetTypePeer: + peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) + if err != nil { + log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err) + target.Host = unknownHostPlaceholder + continue + } + target.Host = peer.IP.String() + case service.TargetTypeHost: + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) + if err != nil { + log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err) + target.Host = unknownHostPlaceholder + continue + } + target.Host = resource.Prefix.Addr().String() + case service.TargetTypeDomain: + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) + if err != nil { + log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err) + target.Host = unknownHostPlaceholder + continue + } + target.Host = resource.Domain + case service.TargetTypeSubnet: + // For subnets we do not do any lookups on the resource + default: + return fmt.Errorf("unknown target type: %s", target.TargetType) + } + } + + return nil +} + +func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) + if err != nil { + return nil, fmt.Errorf("failed to get service: %w", err) + } + + err = m.replaceHostByLookup(ctx, accountID, service) + if err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + } + return service, nil +} + +func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil { + return nil, err + } + + if err := m.persistNewService(ctx, accountID, s); err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta()) + + err = m.replaceHostByLookup(ctx, accountID, s) + if err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) + } + + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return s, nil +} + +func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error { + if m.clusterDeriver != nil { + proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain) + return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err) + } + service.ProxyCluster = proxyCluster + + if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil { + return err + } + } + + service.AccountID = accountID + service.InitNewRecord() + + if err := service.Auth.HashSecrets(); err != nil { + return fmt.Errorf("hash secrets: %w", err) + } + + for i, h := range service.Auth.HeaderAuths { + if h != nil && h.Enabled && h.Value == "" { + return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i) + } + } + + keyPair, err := sessionkey.GenerateKeyPair() + if err != nil { + return fmt.Errorf("generate session keys: %w", err) + } + service.SessionPrivateKey = keyPair.PrivateKey + service.SessionPublicKey = keyPair.PublicKey + + return nil +} + +// validateSubdomainRequirement checks whether the domain can be used bare +// (without a subdomain label) on the given cluster. If the cluster reports +// require_subdomain=true and the domain equals the cluster domain, it rejects. +func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error { + if domain != cluster { + return nil + } + requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster) + if requireSub != nil && *requireSub { + return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain) + } + return nil +} + +func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { + customPorts := m.clusterCustomPorts(ctx, svc) + + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if svc.Domain != "" { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + } + + if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil { + return err + } + + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { + return err + } + + if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { + return err + } + + if err := transaction.CreateService(ctx, svc); err != nil { + return fmt.Errorf("create service: %w", err) + } + + return nil + }) +} + +// clusterCustomPorts queries whether the cluster supports custom ports. +// Must be called before entering a transaction: the underlying query uses +// the main DB handle, which deadlocks when called inside a transaction +// that already holds the connection. +func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service) *bool { + if !service.IsL4Protocol(svc.Mode) { + return nil + } + return m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster) +} + +// ensureL4Port auto-assigns a listen port when needed and validates cluster support. +// customPorts must be pre-computed via clusterCustomPorts before entering a transaction. +func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error { + if !service.IsL4Protocol(svc.Mode) { + return nil + } + if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { + if svc.Source != service.SourceEphemeral { + return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) + } + svc.ListenPort = 0 + } + if svc.ListenPort == 0 { + port, err := m.assignPort(ctx, tx, svc.ProxyCluster) + if err != nil { + return err + } + svc.ListenPort = port + svc.PortAutoAssigned = true + } + return nil +} + +// checkPortConflict rejects L4 services that would conflict on the same listener. +// For TCP/UDP: unique per cluster+protocol+port. +// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports). +// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked: +// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener. +func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 { + return nil + } + + existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort) + if err != nil { + return fmt.Errorf("query port conflicts: %w", err) + } + for _, s := range existing { + if s.ID == svc.ID { + continue + } + // TLS services on the same port are allowed if they have different domains (SNI routing) + if svc.Mode == service.ModeTLS && s.Domain != svc.Domain { + continue + } + return status.Errorf(status.AlreadyExists, + "%s port %d is already in use by service %q on cluster %s", + svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster) + } + + return nil +} + +// assignPort picks a random available port on the cluster within the auto-assign range. +func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) { + services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster) + if err != nil { + return 0, fmt.Errorf("query cluster ports: %w", err) + } + + occupied := make(map[uint16]struct{}, len(services)) + for _, s := range services { + if s.ListenPort > 0 { + occupied[s.ListenPort] = struct{}{} + } + } + + portRange := int(autoAssignPortMax-autoAssignPortMin) + 1 + for range 100 { + port := autoAssignPortMin + uint16(rand.IntN(portRange)) + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + for port := autoAssignPortMin; port <= autoAssignPortMax; port++ { + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster) +} + +// persistNewEphemeralService creates an ephemeral service inside a single transaction +// that also enforces the duplicate and per-peer limit checks atomically. +// The count and exists queries use FOR UPDATE locking to serialize concurrent creates +// for the same peer, preventing the per-peer limit from being bypassed. +func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { + customPorts := m.clusterCustomPorts(ctx, svc) + + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { + return err + } + + if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil { + return err + } + + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { + return err + } + + if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { + return err + } + + if err := transaction.CreateService(ctx, svc); err != nil { + return fmt.Errorf("create service: %w", err) + } + + return nil + }) +} + +func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error { + // Lock the peer row to serialize concurrent creates for the same peer. + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { + return fmt.Errorf("lock peer row: %w", err) + } + + exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) + if err != nil { + return fmt.Errorf("check existing expose: %w", err) + } + if exists { + return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") + } + + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + + count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return fmt.Errorf("count peer exposes: %w", err) + } + if count >= int64(maxExposesPerPeer) { + return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + } + + return nil +} + +// checkDomainAvailable checks that no other service already uses this domain. +func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error { + existingService, err := transaction.GetServiceByDomain(ctx, domain) + if err != nil { + if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { + return fmt.Errorf("check existing service: %w", err) + } + return nil + } + + if existingService != nil && existingService.ID != excludeServiceID { + return status.Errorf(status.AlreadyExists, "domain already taken") + } + + return nil +} + +func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + if err := service.Auth.HashSecrets(); err != nil { + return nil, fmt.Errorf("hash secrets: %w", err) + } + + updateInfo, err := m.persistServiceUpdate(ctx, accountID, service) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta()) + + if err := m.replaceHostByLookup(ctx, accountID, service); err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + } + + m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo) + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return service, nil +} + +type serviceUpdateInfo struct { + oldCluster string + domainChanged bool + serviceEnabledChanged bool +} + +func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) { + effectiveCluster, err := m.resolveEffectiveCluster(ctx, accountID, service) + if err != nil { + return nil, err + } + + svcForCaps := *service + svcForCaps.ProxyCluster = effectiveCluster + customPorts := m.clusterCustomPorts(ctx, &svcForCaps) + + var updateInfo serviceUpdateInfo + + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts) + }) + + return &updateInfo, err +} + +// resolveEffectiveCluster determines the cluster that will be used after the update. +// It reads the existing service without locking and derives the new cluster if the domain changed. +func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, svc *service.Service) (string, error) { + existing, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, svc.ID) + if err != nil { + return "", err + } + + if existing.Domain == svc.Domain { + return existing.ProxyCluster, nil + } + + if m.clusterDeriver != nil { + derived, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) + } else { + return derived, nil + } + } + + return existing.ProxyCluster, nil +} + +func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error { + existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) + if err != nil { + return err + } + + if existingService.Terminated { + return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated") + } + + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } + + updateInfo.oldCluster = existingService.ProxyCluster + updateInfo.domainChanged = existingService.Domain != service.Domain + + if updateInfo.domainChanged { + if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { + return err + } + } else { + service.ProxyCluster = existingService.ProxyCluster + } + + if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { + return err + } + + m.preserveExistingAuthSecrets(service, existingService) + if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { + return err + } + m.preserveServiceMetadata(service, existingService) + m.preserveListenPort(service, existingService) + updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled + + if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil { + return err + } + if err := m.checkPortConflict(ctx, transaction, service); err != nil { + return err + } + if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { + return err + } + if err := transaction.UpdateService(ctx, service); err != nil { + return fmt.Errorf("update service: %w", err) + } + + return nil +} + +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { + return err + } + + if m.clusterDeriver != nil { + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) + } else { + svc.ProxyCluster = newCluster + } + } + + return nil +} + +// validateProtocolChange rejects mode changes on update. +// Only empty<->HTTP is allowed; all other transitions are rejected. +func validateProtocolChange(oldMode, newMode string) error { + if newMode == "" || newMode == oldMode { + return nil + } + if isHTTPFamily(oldMode) && isHTTPFamily(newMode) { + return nil + } + return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode) +} + +func isHTTPFamily(mode string) bool { + return mode == "" || mode == "http" +} + +func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) { + if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled && + existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && + svc.Auth.PasswordAuth.Password == "" { + svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth + } + + if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled && + existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled && + svc.Auth.PinAuth.Pin == "" { + svc.Auth.PinAuth = existingService.Auth.PinAuth + } + + preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths) +} + +// preserveHeaderAuthHashes fills in empty header auth values from the existing +// service so that unchanged secrets are not lost on update. +func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) { + if len(headers) == 0 || len(existing) == 0 { + return + } + existingByHeader := make(map[string]string, len(existing)) + for _, h := range existing { + if h != nil && h.Value != "" { + existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value + } + } + for _, h := range headers { + if h != nil && h.Enabled && h.Value == "" { + if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok { + h.Value = hash + } + } + } +} + +// validateHeaderAuthValues checks that all enabled header auths have a value +// (either freshly provided or preserved from the existing service). +func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error { + for i, h := range headers { + if h != nil && h.Enabled && h.Value == "" { + return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i) + } + } + return nil +} + +func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) { + service.Meta = existingService.Meta + service.SessionPrivateKey = existingService.SessionPrivateKey + service.SessionPublicKey = existingService.SessionPublicKey +} + +func (m *Manager) preserveListenPort(svc, existing *service.Service) { + if existing.ListenPort > 0 && svc.ListenPort == 0 { + svc.ListenPort = existing.ListenPort + svc.PortAutoAssigned = existing.PortAutoAssigned + } +} + +func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) { + oidcCfg := m.proxyController.GetOIDCValidationConfig() + + switch { + case updateInfo.domainChanged || updateInfo.oldCluster != s.ProxyCluster: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) + case !s.Enabled && updateInfo.serviceEnabledChanged: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster) + case s.Enabled && updateInfo.serviceEnabledChanged: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) + default: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster) + } +} + +// validateTargetReferences checks that all target IDs reference existing peers or resources in the account. +func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error { + for _, target := range targets { + switch target.TargetType { + case service.TargetTypePeer: + if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil { + return err + } + case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: + if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil { + return err + } + default: + return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) + } + } + return nil +} + +func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + } + return nil +} + +func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) + } + return validateResourceTargetType(target, resource) +} + +// validateResourceTargetType checks that target_type matches the actual network resource type. +func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error { + expected := resourcetypes.NetworkResourceType(target.TargetType) + if resource.Type != expected { + return status.Errorf(status.InvalidArgument, + "target %q has target_type %q but resource is of type %q", + target.TargetId, target.TargetType, resource.Type, + ) + } + return nil +} + +func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + var s *service.Service + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + if err != nil { + return err + } + + if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil { + return fmt.Errorf("failed to delete targets: %w", err) + } + + if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil { + return fmt.Errorf("failed to delete service: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta()) + + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + var services []*service.Service + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + for _, svc := range services { + if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil { + return fmt.Errorf("failed to delete service: %w", err) + } + } + + return nil + }) + if err != nil { + return err + } + + oidcCfg := m.proxyController.GetOIDCValidationConfig() + + for _, svc := range services { + m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta()) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster) + } + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +// SetCertificateIssuedAt sets the certificate issued timestamp to the current time. +// Call this when receiving a gRPC notification that the certificate was issued. +func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + if err != nil { + return fmt.Errorf("failed to get service: %w", err) + } + + now := time.Now() + service.Meta.CertificateIssuedAt = &now + + if err = transaction.UpdateService(ctx, service); err != nil { + return fmt.Errorf("failed to update service certificate timestamp: %w", err) + } + + return nil + }) +} + +// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) +func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error { + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + if err != nil { + return fmt.Errorf("failed to get service: %w", err) + } + + service.Meta.Status = string(status) + + if err = transaction.UpdateService(ctx, service); err != nil { + return fmt.Errorf("failed to update service status: %w", err) + } + + return nil + }) +} + +func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error { + s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) + if err != nil { + return fmt.Errorf("failed to get service: %w", err) + } + + err = m.replaceHostByLookup(ctx, accountID, s) + if err != nil { + return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) + } + + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { + services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("failed to get services: %w", err) + } + + for _, s := range services { + err = m.replaceHostByLookup(ctx, accountID, s) + if err != nil { + return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) + } + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) + } + + return nil +} + +func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { + services, err := m.store.GetServices(ctx, store.LockingStrengthNone) + if err != nil { + return nil, fmt.Errorf("failed to get services: %w", err) + } + + for _, service := range services { + err = m.replaceHostByLookup(ctx, service.AccountID, service) + if err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + } + } + + return services, nil +} + +func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) { + service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) + if err != nil { + return nil, fmt.Errorf("failed to get service: %w", err) + } + + err = m.replaceHostByLookup(ctx, accountID, service) + if err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + } + + return service, nil +} + +func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { + services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get services: %w", err) + } + + for _, service := range services { + err = m.replaceHostByLookup(ctx, accountID, service) + if err != nil { + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + } + } + + return services, nil +} + +func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { + target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + return "", nil + } + return "", fmt.Errorf("failed to get service target by resource ID: %w", err) + } + + if target == nil { + return "", nil + } + + return target.ServiceID, nil +} + +// validateExposePermission checks whether the peer is allowed to use the expose feature. +// It verifies the account has peer expose enabled and that the peer belongs to an allowed group. +func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error { + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return status.Errorf(status.Internal, "get account settings: %v", err) + } + + if !settings.PeerExposeEnabled { + return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account") + } + + if len(settings.PeerExposeGroups) == 0 { + return status.Errorf(status.PermissionDenied, "no group is set for peer expose") + } + + peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err) + return status.Errorf(status.Internal, "get peer groups: %v", err) + } + + for _, pg := range peerGroupIDs { + if slices.Contains(settings.PeerExposeGroups, pg) { + return nil + } + } + + return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group") +} + +func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) { + return m.buildRandomDomain(serviceName) +} + +// CreateServiceFromPeer creates a service initiated by a peer expose request. +// It validates the request, checks expose permissions, enforces the per-peer limit, +// creates the service, and tracks it for TTL-based reaping. +func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + if err := req.Validate(); err != nil { + return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err) + } + + if err := m.validateExposePermission(ctx, accountID, peerID); err != nil { + return nil, err + } + + serviceName, err := service.GenerateExposeName(req.NamePrefix) + if err != nil { + return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err) + } + + svc := req.ToService(accountID, peerID, serviceName) + svc.Source = service.SourceEphemeral + + if svc.Domain == "" { + domain, err := m.resolveDefaultDomain(svc.Name) + if err != nil { + return nil, err + } + svc.Domain = domain + } + + if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled { + groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups) + if err != nil { + return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err) + } + svc.Auth.BearerAuth.DistributionGroups = groupIDs + } + + if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil { + return nil, err + } + + peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return nil, err + } + + svc.SourcePeer = peerID + + now := time.Now() + svc.Meta.LastRenewedAt = &now + + if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil { + return nil, err + } + + meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) + m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta) + + if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil { + return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err) + } + + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) + m.accountManager.UpdateAccountPeers(ctx, accountID) + + serviceURL := "https://" + svc.Domain + if service.IsL4Protocol(svc.Mode) { + serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort) + } + + return &service.ExposeServiceResponse{ + ServiceName: svc.Name, + ServiceURL: serviceURL, + Domain: svc.Domain, + PortAutoAssigned: svc.PortAutoAssigned, + }, nil +} + +func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) { + if len(groupNames) == 0 { + return []string{}, fmt.Errorf("no group names provided") + } + groupIDs := make([]string, 0, len(groupNames)) + for _, groupName := range groupNames { + g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator) + if err != nil { + return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) + } + groupIDs = append(groupIDs, g.ID) + } + return groupIDs, nil +} + +func (m *Manager) getDefaultClusterDomain() (string, error) { + if m.clusterDeriver == nil { + return "", fmt.Errorf("unable to get cluster domain") + } + clusterDomains := m.clusterDeriver.GetClusterDomains() + if len(clusterDomains) == 0 { + return "", fmt.Errorf("no cluster domains available") + } + return clusterDomains[rand.IntN(len(clusterDomains))], nil +} + +func (m *Manager) buildRandomDomain(name string) (string, error) { + domain, err := m.getDefaultClusterDomain() + if err != nil { + return "", err + } + return name + "." + domain, nil +} + +// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service. +func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID) +} + +// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB. +func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err) + return err + } + return nil +} + +// deleteServiceFromPeer deletes a peer-initiated service identified by service ID. +// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. +func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error { + activityCode := activity.PeerServiceUnexposed + if expired { + activityCode = activity.PeerServiceExposeExpired + } + return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode) +} + +func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { + var svc *service.Service + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + if err != nil { + return err + } + + if svc.Source != service.SourceEphemeral { + return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose") + } + + if svc.SourcePeer != peerID { + return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer") + } + + if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil { + return fmt.Errorf("delete service: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err) + peer = nil + } + + meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) + + m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta) + + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking +// that it is still expired under a row lock. This prevents deleting a service +// that was renewed between the batch query and this delete, and ensures only one +// management instance processes the deletion +func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error { + var svc *service.Service + deleted := false + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + if err != nil { + return err + } + + if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID { + return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner") + } + + if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL { + return nil + } + + if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil { + return fmt.Errorf("delete service: %w", err) + } + deleted = true + + return nil + }) + if err != nil { + return err + } + + if !deleted { + return nil + } + + peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err) + peer = nil + } + + meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) + m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any { + if peer == nil { + return meta + } + meta["peer_name"] = peer.Name + if peer.IP != nil { + meta["peer_ip"] = peer.IP.String() + } + return meta +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go new file mode 100644 index 000000000..54ac8ab18 --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -0,0 +1,1343 @@ +package manager + +import ( + "context" + "errors" + "net" + "testing" + "time" + + cachestore "github.com/eko/gocache/lib/v4/store" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/mock_server" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + +func TestInitializeServiceForCreate(t *testing.T) { + ctx := context.Background() + accountID := "test-account" + + t.Run("successful initialization without cluster deriver", func(t *testing.T) { + mgr := &Manager{ + clusterDeriver: nil, + } + + service := &rpservice.Service{ + Domain: "example.com", + Auth: rpservice.AuthConfig{}, + } + + err := mgr.initializeServiceForCreate(ctx, accountID, service) + + assert.NoError(t, err) + assert.Equal(t, accountID, service.AccountID) + assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver") + assert.NotEmpty(t, service.ID, "service ID should be initialized") + assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated") + assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated") + }) + + t.Run("verifies session keys are different", func(t *testing.T) { + mgr := &Manager{ + clusterDeriver: nil, + } + + service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}} + service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}} + + err1 := mgr.initializeServiceForCreate(ctx, accountID, service1) + err2 := mgr.initializeServiceForCreate(ctx, accountID, service2) + + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique") + assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique") + }) +} + +func TestCheckDomainAvailable(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + domain string + excludeServiceID string + setupMock func(*store.MockStore) + expectedError bool + errorType status.Type + }{ + { + name: "domain available - not found", + domain: "available.com", + excludeServiceID: "", + setupMock: func(ms *store.MockStore) { + ms.EXPECT(). + GetServiceByDomain(ctx, "available.com"). + Return(nil, status.Errorf(status.NotFound, "not found")) + }, + expectedError: false, + }, + { + name: "domain already exists", + domain: "exists.com", + excludeServiceID: "", + setupMock: func(ms *store.MockStore) { + ms.EXPECT(). + GetServiceByDomain(ctx, "exists.com"). + Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil) + }, + expectedError: true, + errorType: status.AlreadyExists, + }, + { + name: "domain exists but excluded (same ID)", + domain: "exists.com", + excludeServiceID: "service-123", + setupMock: func(ms *store.MockStore) { + ms.EXPECT(). + GetServiceByDomain(ctx, "exists.com"). + Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) + }, + expectedError: false, + }, + { + name: "domain exists with different ID", + domain: "exists.com", + excludeServiceID: "service-456", + setupMock: func(ms *store.MockStore) { + ms.EXPECT(). + GetServiceByDomain(ctx, "exists.com"). + Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) + }, + expectedError: true, + errorType: status.AlreadyExists, + }, + { + name: "store error (non-NotFound)", + domain: "error.com", + excludeServiceID: "", + setupMock: func(ms *store.MockStore) { + ms.EXPECT(). + GetServiceByDomain(ctx, "error.com"). + Return(nil, errors.New("database error")) + }, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + tt.setupMock(mockStore) + + mgr := &Manager{} + err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID) + + if tt.expectedError { + require.Error(t, err) + if tt.errorType != 0 { + sErr, ok := status.FromError(err) + require.True(t, ok, "error should be a status error") + assert.Equal(t, tt.errorType, sErr.Type()) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCheckDomainAvailable_EdgeCases(t *testing.T) { + ctx := context.Background() + + t.Run("empty domain", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetServiceByDomain(ctx, ""). + Return(nil, status.Errorf(status.NotFound, "not found")) + + mgr := &Manager{} + err := mgr.checkDomainAvailable(ctx, mockStore, "", "") + + assert.NoError(t, err) + }) + + t.Run("empty exclude ID with existing service", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetServiceByDomain(ctx, "test.com"). + Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil) + + mgr := &Manager{} + err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "") + + assert.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.AlreadyExists, sErr.Type()) + }) + + t.Run("nil existing service with nil error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetServiceByDomain(ctx, "nil.com"). + Return(nil, nil) + + mgr := &Manager{} + err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "") + + assert.NoError(t, err) + }) +} + +func TestPersistNewService(t *testing.T) { + ctx := context.Background() + accountID := "test-account" + + t.Run("successful service creation with no targets", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + service := &rpservice.Service{ + ID: "service-123", + Domain: "new.com", + Targets: []*rpservice.Target{}, + } + + // Mock ExecuteInTransaction to execute the function immediately + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + // Create another mock for the transaction + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByDomain(ctx, "new.com"). + Return(nil, status.Errorf(status.NotFound, "not found")) + txMock.EXPECT(). + CreateService(ctx, service). + Return(nil) + + return fn(txMock) + }) + + mgr := &Manager{store: mockStore} + err := mgr.persistNewService(ctx, accountID, service) + + assert.NoError(t, err) + }) + + t.Run("domain already exists", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + service := &rpservice.Service{ + ID: "service-123", + Domain: "existing.com", + Targets: []*rpservice.Target{}, + } + + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByDomain(ctx, "existing.com"). + Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil) + + return fn(txMock) + }) + + mgr := &Manager{store: mockStore} + err := mgr.persistNewService(ctx, accountID, service) + + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.AlreadyExists, sErr.Type()) + }) +} +func TestPreserveExistingAuthSecrets(t *testing.T) { + mgr := &Manager{} + + t.Run("preserve password when empty", func(t *testing.T) { + existing := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ + Enabled: true, + Password: "hashed-password", + }, + }, + } + + updated := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ + Enabled: true, + Password: "", + }, + }, + } + + mgr.preserveExistingAuthSecrets(updated, existing) + + assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth) + }) + + t.Run("preserve pin when empty", func(t *testing.T) { + existing := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PinAuth: &rpservice.PINAuthConfig{ + Enabled: true, + Pin: "hashed-pin", + }, + }, + } + + updated := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PinAuth: &rpservice.PINAuthConfig{ + Enabled: true, + Pin: "", + }, + }, + } + + mgr.preserveExistingAuthSecrets(updated, existing) + + assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth) + }) + + t.Run("do not preserve when password is provided", func(t *testing.T) { + existing := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ + Enabled: true, + Password: "old-password", + }, + }, + } + + updated := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ + Enabled: true, + Password: "new-password", + }, + }, + } + + mgr.preserveExistingAuthSecrets(updated, existing) + + assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password) + assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth) + }) +} + +func TestPreserveServiceMetadata(t *testing.T) { + mgr := &Manager{} + + existing := &rpservice.Service{ + Meta: rpservice.Meta{ + CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(), + Status: "active", + }, + SessionPrivateKey: "private-key", + SessionPublicKey: "public-key", + } + + updated := &rpservice.Service{ + Domain: "updated.com", + } + + mgr.preserveServiceMetadata(updated, existing) + + assert.Equal(t, existing.Meta, updated.Meta) + assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey) + assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey) +} + +func TestDeletePeerService_SourcePeerValidation(t *testing.T) { + ctx := context.Background() + accountID := "test-account" + ownerPeerID := "peer-owner" + otherPeerID := "peer-other" + serviceID := "service-123" + + testPeer := &nbpeer.Peer{ + ID: ownerPeerID, + Name: "test-peer", + IP: net.ParseIP("100.64.0.1"), + } + + newEphemeralService := func() *rpservice.Service { + return &rpservice.Service{ + ID: serviceID, + AccountID: accountID, + Name: "test-service", + Domain: "test.example.com", + Source: rpservice.SourceEphemeral, + SourcePeer: ownerPeerID, + } + } + + newPermanentService := func() *rpservice.Service { + return &rpservice.Service{ + ID: serviceID, + AccountID: accountID, + Name: "api-service", + Domain: "api.example.com", + Source: rpservice.SourcePermanent, + } + } + + newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer { + t.Helper() + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + return srv + } + + t.Run("owner peer can delete own service", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var storedActivity activity.Activity + mockStore := store.NewMockStore(ctrl) + mockAccountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { + storedActivity = activityID.(activity.Activity) + }, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + } + + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID). + Return(newEphemeralService(), nil) + txMock.EXPECT(). + DeleteService(ctx, accountID, serviceID). + Return(nil) + return fn(txMock) + }) + mockStore.EXPECT(). + GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). + Return(testPeer, nil) + + mgr := &Manager{ + store: mockStore, + accountManager: mockAccountMgr, + proxyController: func() proxy.Controller { + c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + return c + }(), + } + + err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) + require.NoError(t, err) + assert.Equal(t, activity.PeerServiceUnexposed, storedActivity, "should store unexposed activity") + }) + + t.Run("different peer cannot delete service", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID). + Return(newEphemeralService(), nil) + return fn(txMock) + }) + + mgr := &Manager{ + store: mockStore, + } + + err := mgr.deletePeerService(ctx, accountID, otherPeerID, serviceID, activity.PeerServiceUnexposed) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok, "should be a status error") + assert.Equal(t, status.PermissionDenied, sErr.Type(), "should be permission denied") + assert.Contains(t, err.Error(), "another peer") + }) + + t.Run("cannot delete API-created service", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID). + Return(newPermanentService(), nil) + return fn(txMock) + }) + + mgr := &Manager{ + store: mockStore, + } + + err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok, "should be a status error") + assert.Equal(t, status.PermissionDenied, sErr.Type(), "should be permission denied") + assert.Contains(t, err.Error(), "API-created") + }) + + t.Run("expire uses correct activity code", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var storedActivity activity.Activity + mockStore := store.NewMockStore(ctrl) + mockAccountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { + storedActivity = activityID.(activity.Activity) + }, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + } + + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID). + Return(newEphemeralService(), nil) + txMock.EXPECT(). + DeleteService(ctx, accountID, serviceID). + Return(nil) + return fn(txMock) + }) + mockStore.EXPECT(). + GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). + Return(testPeer, nil) + + mgr := &Manager{ + store: mockStore, + accountManager: mockAccountMgr, + proxyController: func() proxy.Controller { + c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + return c + }(), + } + + err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired) + require.NoError(t, err) + assert.Equal(t, activity.PeerServiceExposeExpired, storedActivity, "should store expired activity") + }) + + t.Run("event meta includes peer info", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var storedMeta map[string]any + mockStore := store.NewMockStore(ctrl) + mockAccountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) { + storedMeta = meta + }, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + } + + mockStore.EXPECT(). + ExecuteInTransaction(ctx, gomock.Any()). + DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { + txMock := store.NewMockStore(ctrl) + txMock.EXPECT(). + GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID). + Return(newEphemeralService(), nil) + txMock.EXPECT(). + DeleteService(ctx, accountID, serviceID). + Return(nil) + return fn(txMock) + }) + mockStore.EXPECT(). + GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). + Return(testPeer, nil) + + mgr := &Manager{ + store: mockStore, + accountManager: mockAccountMgr, + proxyController: func() proxy.Controller { + c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + return c + }(), + } + + err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) + require.NoError(t, err) + require.NotNil(t, storedMeta) + assert.Equal(t, "test-peer", storedMeta["peer_name"], "meta should contain peer name") + assert.Equal(t, "100.64.0.1", storedMeta["peer_ip"], "meta should contain peer IP") + assert.Equal(t, "test-service", storedMeta["name"], "meta should contain service name") + assert.Equal(t, "test.example.com", storedMeta["domain"], "meta should contain service domain") + }) +} + +// testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list. +type testClusterDeriver struct { + domains []string +} + +func (d *testClusterDeriver) DeriveClusterFromDomain(_ context.Context, _, domain string) (string, error) { + return "test-cluster", nil +} + +func (d *testClusterDeriver) GetClusterDomains() []string { + return d.domains +} + +const ( + testAccountID = "test-account" + testPeerID = "test-peer-1" + testGroupID = "test-group-1" + testUserID = "test-user" +) + +// setupIntegrationTest creates a real SQLite store with seeded test data for integration tests. +func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { + t.Helper() + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + CreatedBy: testUserID, + Settings: &types.Settings{ + PeerExposeEnabled: true, + PeerExposeGroups: []string{testGroupID}, + }, + Users: map[string]*types.User{ + testUserID: { + Id: testUserID, + AccountID: testAccountID, + Role: types.UserRoleAdmin, + }, + }, + Peers: map[string]*nbpeer.Peer{ + testPeerID: { + ID: testPeerID, + AccountID: testAccountID, + Key: "test-key", + DNSLabel: "test-peer", + Name: "test-peer", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, + }, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + AccountID: testAccountID, + Name: "Expose Group", + }, + }, + }) + require.NoError(t, err) + + err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) + require.NoError(t, err) + + permsMgr := permissions.NewManager(testStore) + + accountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) + }, + } + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + + proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + + mgr := &Manager{ + store: testStore, + accountManager: accountMgr, + permissionsManager: permsMgr, + proxyController: proxyController, + clusterDeriver: &testClusterDeriver{ + domains: []string{"test.netbird.io"}, + }, + } + mgr.exposeReaper = &exposeReaper{manager: mgr} + + return mgr, testStore +} + +func Test_validateExposePermission(t *testing.T) { + ctx := context.Background() + + t.Run("allowed when peer is in expose group", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + err := mgr.validateExposePermission(ctx, testAccountID, testPeerID) + assert.NoError(t, err) + }) + + t.Run("denied when peer is not in expose group", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + // Add a peer that is NOT in the expose group + otherPeerID := "other-peer" + err := testStore.AddPeerToAccount(ctx, &nbpeer.Peer{ + ID: otherPeerID, + AccountID: testAccountID, + Key: "other-key", + DNSLabel: "other-peer", + Name: "other-peer", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{LastSeen: time.Now()}, + Meta: nbpeer.PeerSystemMeta{Hostname: "other-peer"}, + }) + require.NoError(t, err) + + err = mgr.validateExposePermission(ctx, testAccountID, otherPeerID) + require.Error(t, err) + assert.Contains(t, err.Error(), "not in an allowed expose group") + }) + + t.Run("denied when expose is disabled", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + // Disable peer expose + s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + s.PeerExposeEnabled = false + err = testStore.SaveAccountSettings(ctx, testAccountID, s) + require.NoError(t, err) + + err = mgr.validateExposePermission(ctx, testAccountID, testPeerID) + require.Error(t, err) + assert.Contains(t, err.Error(), "not enabled") + }) + + t.Run("disallowed when no groups configured", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + // Enable expose with empty groups — no groups configured means no peer is allowed + s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + s.PeerExposeGroups = []string{} + err = testStore.SaveAccountSettings(ctx, testAccountID, s) + require.NoError(t, err) + + err = mgr.validateExposePermission(ctx, testAccountID, testPeerID) + assert.Error(t, err) + }) + + t.Run("error when store returns error", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error")) + mgr := &Manager{store: mockStore} + err := mgr.validateExposePermission(ctx, testAccountID, testPeerID) + require.Error(t, err) + assert.Contains(t, err.Error(), "get account settings") + }) +} + +func TestCreateServiceFromPeer(t *testing.T) { + ctx := context.Background() + + t.Run("creates service with random domain", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + req := &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + } + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + assert.NotEmpty(t, resp.ServiceName, "service name should be generated") + assert.Contains(t, resp.Domain, "test.netbird.io", "domain should use cluster domain") + assert.NotEmpty(t, resp.ServiceURL, "service URL should be set") + + // Verify service is persisted in store + persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err) + assert.Equal(t, resp.Domain, persisted.Domain) + assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral") + assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set") + assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set") + }) + + t.Run("creates service with custom domain", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + + req := &rpservice.ExposeServiceRequest{ + Port: 80, + Mode: "http", + Domain: "example.com", + } + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + assert.Contains(t, resp.Domain, "example.com", "should use the provided domain") + }) + + t.Run("validates expose permission internally", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + // Disable peer expose + s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + s.PeerExposeEnabled = false + err = testStore.SaveAccountSettings(ctx, testAccountID, s) + require.NoError(t, err) + + req := &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + } + + _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.Error(t, err) + assert.Contains(t, err.Error(), "not enabled") + }) + + t.Run("validates request fields", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + + req := &rpservice.ExposeServiceRequest{ + Port: 0, + Mode: "http", + } + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.Error(t, err) + assert.Contains(t, err.Error(), "port") + }) +} + +func TestExposeServiceRequestValidate(t *testing.T) { + tests := []struct { + name string + req rpservice.ExposeServiceRequest + wantErr string + }{ + { + name: "valid http request", + req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"}, + wantErr: "", + }, + { + name: "https mode rejected", + req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"}, + wantErr: "unsupported mode", + }, + { + name: "port zero rejected", + req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"}, + wantErr: "port must be between 1 and 65535", + }, + { + name: "unsupported mode", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"}, + wantErr: "unsupported mode", + }, + { + name: "invalid pin format", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"}, + wantErr: "invalid pin", + }, + { + name: "pin too short", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"}, + wantErr: "invalid pin", + }, + { + name: "valid 6-digit pin", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"}, + wantErr: "", + }, + { + name: "empty user group name", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}}, + wantErr: "user group name cannot be empty", + }, + { + name: "invalid name prefix", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"}, + wantErr: "invalid name prefix", + }, + { + name: "valid name prefix", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"}, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.req.Validate() + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } + + t.Run("nil receiver", func(t *testing.T) { + var req *rpservice.ExposeServiceRequest + err := req.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "request cannot be nil") + }) +} + +func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { + ctx := context.Background() + + t.Run("deletes service by domain", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + // First create a service + req := &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + } + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false) + require.NoError(t, err) + + // Verify service is deleted + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) + require.Error(t, err, "service should be deleted") + }) + + t.Run("expire uses correct activity", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + req := &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + } + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true) + require.NoError(t, err) + }) +} + +func TestStopServiceFromPeer(t *testing.T) { + ctx := context.Background() + + t.Run("stops service by domain", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + req := &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + } + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) + require.Error(t, err, "service should be deleted") + }) +} + +func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { + ctx := context.Background() + mgr, testStore := setupIntegrationTest(t) + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + }) + require.NoError(t, err) + + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(1), count, "one ephemeral service should exist after create") + + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err) + + err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) + require.NoError(t, err) + + count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete") + + _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 9090, + Mode: "http", + }) + assert.NoError(t, err, "new expose should succeed after API delete") +} + +func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) { + ctx := context.Background() + mgr, _ := setupIntegrationTest(t) + + for i := range 3 { + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: uint16(8080 + i), + Mode: "http", + }) + require.NoError(t, err) + } + + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(3), count, "all ephemeral services should exist") + + err = mgr.DeleteAllServices(ctx, testAccountID, testUserID) + require.NoError(t, err) + + count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count, "all ephemeral services should be deleted after DeleteAllServices") +} + +func TestRenewServiceFromPeer(t *testing.T) { + ctx := context.Background() + + t.Run("renews tracked expose", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "http", + }) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + }) + + t.Run("fails for untracked domain", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") + require.Error(t, err) + }) +} + +func TestGetGroupIDsFromNames(t *testing.T) { + ctx := context.Background() + + t.Run("resolves group names to IDs", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ids, err := mgr.getGroupIDsFromNames(ctx, testAccountID, []string{"Expose Group"}) + require.NoError(t, err) + require.Len(t, ids, 1, "should return exactly one group ID") + assert.Equal(t, testGroupID, ids[0]) + }) + + t.Run("returns error for unknown group", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + _, err := mgr.getGroupIDsFromNames(ctx, testAccountID, []string{"nonexistent"}) + require.Error(t, err) + }) + + t.Run("returns error for empty group list", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + _, err := mgr.getGroupIDsFromNames(ctx, testAccountID, []string{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no group names provided") + }) +} + +func TestDeleteService_DeletesTargets(t *testing.T) { + ctx := context.Background() + accountID := "test-account" + userID := "test-user" + + sqlStore, err := store.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPerms := permissions.NewMockManager(ctrl) + mockAcct := account.NewMockManager(ctrl) + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + + proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + + mgr := &Manager{ + store: sqlStore, + permissionsManager: mockPerms, + accountManager: mockAcct, + proxyController: proxyController, + } + + service := &rpservice.Service{ + ID: "service-1", + AccountID: accountID, + Domain: "test.example.com", + ProxyCluster: "cluster1", + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-1"}, + {AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-2"}, + {AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-3"}, + }, + } + + err = sqlStore.CreateService(ctx, service) + require.NoError(t, err) + + retrievedService, err := sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID) + require.NoError(t, err) + require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion") + + mockPerms.EXPECT(). + ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete). + Return(true, nil) + mockAcct.EXPECT(). + StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) + mockAcct.EXPECT(). + UpdateAccountPeers(ctx, accountID) + + err = mgr.DeleteService(ctx, accountID, userID, service.ID) + require.NoError(t, err) + + _, err = sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID) + require.Error(t, err) + s, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, s.Type()) + + targets, err := sqlStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, accountID, service.ID) + require.NoError(t, err) + assert.Len(t, targets, 0, "All targets should be deleted when service is deleted") +} + +func TestValidateProtocolChange(t *testing.T) { + tests := []struct { + name string + oldP string + newP string + wantErr bool + }{ + {"empty to http", "", "http", false}, + {"http to http", "http", "http", false}, + {"same protocol", "tcp", "tcp", false}, + {"empty new proto", "tcp", "", false}, + {"http to tcp", "http", "tcp", true}, + {"tcp to udp", "tcp", "udp", true}, + {"tls to http", "tls", "http", true}, + {"udp to tls", "udp", "tls", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateProtocolChange(tt.oldP, tt.newP) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot change mode") + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + tests := []struct { + name string + targetType rpservice.TargetType + resourceType resourcetypes.NetworkResourceType + wantErr bool + }{ + {"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false}, + {"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false}, + {"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false}, + {"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true}, + {"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true}, + {"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true}, + {"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore.EXPECT(). + GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1"). + Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil) + + targets := []*rpservice.Target{ + {TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"}, + } + err := validateTargetReferences(ctx, mockStore, accountID, targets) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "target_type") + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateTargetReferences_PeerValid(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + mockStore.EXPECT(). + GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1"). + Return(&nbpeer.Peer{}, nil) + + targets := []*rpservice.Target{ + {TargetId: "peer-1", TargetType: rpservice.TargetTypePeer}, + } + require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets)) +} + +func TestValidateSubdomainRequirement(t *testing.T) { + ptrBool := func(b bool) *bool { return &b } + + tests := []struct { + name string + domain string + cluster string + requireSubdomain *bool + wantErr bool + }{ + { + name: "subdomain present, require_subdomain true", + domain: "app.eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(true), + wantErr: false, + }, + { + name: "bare cluster domain, require_subdomain true", + domain: "eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(true), + wantErr: true, + }, + { + name: "bare cluster domain, require_subdomain false", + domain: "eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(false), + wantErr: false, + }, + { + name: "bare cluster domain, require_subdomain nil (default)", + domain: "eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: nil, + wantErr: false, + }, + { + name: "custom domain apex is not the cluster", + domain: "example.com", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(true), + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes() + + mgr := &Manager{capabilities: mockCaps} + err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster) + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "requires a subdomain label") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go new file mode 100644 index 000000000..769e037bc --- /dev/null +++ b/management/internals/modules/reverseproxy/service/service.go @@ -0,0 +1,1390 @@ +package service + +import ( + "crypto/rand" + "errors" + "fmt" + "math/big" + "net" + "net/http" + "net/netip" + "net/url" + "regexp" + "slices" + "strconv" + "strings" + "time" + + "github.com/rs/xid" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/shared/hash/argon2id" + "github.com/netbirdio/netbird/util/crypt" + + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type Operation string + +const ( + Create Operation = "create" + Update Operation = "update" + Delete Operation = "delete" +) + +type Status string +type TargetType string + +const ( + StatusPending Status = "pending" + StatusActive Status = "active" + StatusTunnelNotCreated Status = "tunnel_not_created" + StatusCertificatePending Status = "certificate_pending" + StatusCertificateFailed Status = "certificate_failed" + StatusError Status = "error" + + TargetTypePeer TargetType = "peer" + TargetTypeHost TargetType = "host" + TargetTypeDomain TargetType = "domain" + TargetTypeSubnet TargetType = "subnet" + + SourcePermanent = "permanent" + SourceEphemeral = "ephemeral" +) + +type TargetOptions struct { + SkipTLSVerify bool `json:"skip_tls_verify"` + RequestTimeout time.Duration `json:"request_timeout,omitempty"` + SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"` + PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` + CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` +} + +type Target struct { + ID uint `gorm:"primaryKey" json:"-"` + AccountID string `gorm:"index:idx_target_account;not null" json:"-"` + ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` + Path *string `json:"path,omitempty"` + Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored + Port uint16 `gorm:"index:idx_target_port" json:"port"` + Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` + TargetId string `gorm:"index:idx_target_id" json:"target_id"` + TargetType TargetType `gorm:"index:idx_target_type" json:"target_type"` + Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` + Options TargetOptions `gorm:"embedded" json:"options"` + ProxyProtocol bool `json:"proxy_protocol"` +} + +type PasswordAuthConfig struct { + Enabled bool `json:"enabled"` + Password string `json:"password"` +} + +type PINAuthConfig struct { + Enabled bool `json:"enabled"` + Pin string `json:"pin"` +} + +type BearerAuthConfig struct { + Enabled bool `json:"enabled"` + DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"` +} + +// HeaderAuthConfig defines a static header-value auth check. +// The proxy compares the incoming header value against the stored hash. +type HeaderAuthConfig struct { + Enabled bool `json:"enabled"` + Header string `json:"header"` + Value string `json:"value"` +} + +type AuthConfig struct { + PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"` + PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"` + BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"` + HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"` +} + +// AccessRestrictions controls who can connect to the service based on IP or geography. +type AccessRestrictions struct { + AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"` + BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"` + AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"` + BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"` + CrowdSecMode string `json:"crowdsec_mode,omitempty" gorm:"serializer:json"` +} + +// Copy returns a deep copy of the AccessRestrictions. +func (r AccessRestrictions) Copy() AccessRestrictions { + return AccessRestrictions{ + AllowedCIDRs: slices.Clone(r.AllowedCIDRs), + BlockedCIDRs: slices.Clone(r.BlockedCIDRs), + AllowedCountries: slices.Clone(r.AllowedCountries), + BlockedCountries: slices.Clone(r.BlockedCountries), + CrowdSecMode: r.CrowdSecMode, + } +} + +func (a *AuthConfig) HashSecrets() error { + if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" { + hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + a.PasswordAuth.Password = hashedPassword + } + + if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" { + hashedPin, err := argon2id.Hash(a.PinAuth.Pin) + if err != nil { + return fmt.Errorf("hash pin: %w", err) + } + a.PinAuth.Pin = hashedPin + } + + for i, h := range a.HeaderAuths { + if h != nil && h.Enabled && h.Value != "" { + hashedValue, err := argon2id.Hash(h.Value) + if err != nil { + return fmt.Errorf("hash header auth[%d] value: %w", i, err) + } + h.Value = hashedValue + } + } + + return nil +} + +func (a *AuthConfig) ClearSecrets() { + if a.PasswordAuth != nil { + a.PasswordAuth.Password = "" + } + if a.PinAuth != nil { + a.PinAuth.Pin = "" + } + for _, h := range a.HeaderAuths { + if h != nil { + h.Value = "" + } + } +} + +type Meta struct { + CreatedAt time.Time + CertificateIssuedAt *time.Time + Status string + LastRenewedAt *time.Time +} + +type Service struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Domain string `gorm:"type:varchar(255);uniqueIndex"` + ProxyCluster string `gorm:"index"` + Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"` + Enabled bool + Terminated bool + PassHostHeader bool + RewriteRedirects bool + Auth AuthConfig `gorm:"serializer:json"` + Restrictions AccessRestrictions `gorm:"serializer:json"` + Meta Meta `gorm:"embedded;embeddedPrefix:meta_"` + SessionPrivateKey string `gorm:"column:session_private_key"` + SessionPublicKey string `gorm:"column:session_public_key"` + Source string `gorm:"default:'permanent';index:idx_service_source_peer"` + SourcePeer string `gorm:"index:idx_service_source_peer"` + // Mode determines the service type: "http", "tcp", "udp", or "tls". + Mode string `gorm:"default:'http'"` + ListenPort uint16 + PortAutoAssigned bool +} + +// InitNewRecord generates a new unique ID and resets metadata for a newly created +// Service record. This overwrites any existing ID and Meta fields and should +// only be called during initial creation, not for updates. +func (s *Service) InitNewRecord() { + s.ID = xid.New().String() + s.Meta = Meta{ + CreatedAt: time.Now(), + Status: string(StatusPending), + } +} + +func (s *Service) ToAPIResponse() *api.Service { + authConfig := api.ServiceAuthConfig{} + + if s.Auth.PasswordAuth != nil { + authConfig.PasswordAuth = &api.PasswordAuthConfig{ + Enabled: s.Auth.PasswordAuth.Enabled, + } + } + + if s.Auth.PinAuth != nil { + authConfig.PinAuth = &api.PINAuthConfig{ + Enabled: s.Auth.PinAuth.Enabled, + } + } + + if s.Auth.BearerAuth != nil { + authConfig.BearerAuth = &api.BearerAuthConfig{ + Enabled: s.Auth.BearerAuth.Enabled, + DistributionGroups: &s.Auth.BearerAuth.DistributionGroups, + } + } + + if len(s.Auth.HeaderAuths) > 0 { + apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths)) + for _, h := range s.Auth.HeaderAuths { + if h == nil { + continue + } + apiHeaders = append(apiHeaders, api.HeaderAuthConfig{ + Enabled: h.Enabled, + Header: h.Header, + }) + } + authConfig.HeaderAuths = &apiHeaders + } + + // Convert internal targets to API targets + apiTargets := make([]api.ServiceTarget, 0, len(s.Targets)) + for _, target := range s.Targets { + st := api.ServiceTarget{ + Path: target.Path, + Host: &target.Host, + Port: int(target.Port), + Protocol: api.ServiceTargetProtocol(target.Protocol), + TargetId: target.TargetId, + TargetType: api.ServiceTargetTargetType(target.TargetType), + Enabled: target.Enabled && !s.Terminated, + } + opts := targetOptionsToAPI(target.Options) + if opts == nil { + opts = &api.ServiceTargetOptions{} + } + if target.ProxyProtocol { + opts.ProxyProtocol = &target.ProxyProtocol + } + st.Options = opts + apiTargets = append(apiTargets, st) + } + + meta := api.ServiceMeta{ + CreatedAt: s.Meta.CreatedAt, + Status: api.ServiceMetaStatus(s.Meta.Status), + } + + if s.Meta.CertificateIssuedAt != nil { + meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt + } + + mode := api.ServiceMode(s.Mode) + listenPort := int(s.ListenPort) + + resp := &api.Service{ + Id: s.ID, + Name: s.Name, + Domain: s.Domain, + Targets: apiTargets, + Enabled: s.Enabled && !s.Terminated, + Terminated: &s.Terminated, + PassHostHeader: &s.PassHostHeader, + RewriteRedirects: &s.RewriteRedirects, + Auth: authConfig, + AccessRestrictions: restrictionsToAPI(s.Restrictions), + Meta: meta, + Mode: &mode, + ListenPort: &listenPort, + PortAutoAssigned: &s.PortAutoAssigned, + } + + if s.ProxyCluster != "" { + resp.ProxyCluster = &s.ProxyCluster + } + + return resp +} + +func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { + pathMappings := s.buildPathMappings() + + auth := &proto.Authentication{ + SessionKey: s.SessionPublicKey, + MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()), + } + + if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled { + auth.Password = true + } + + if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled { + auth.Pin = true + } + + if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled { + auth.Oidc = true + } + + for _, h := range s.Auth.HeaderAuths { + if h != nil && h.Enabled { + auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{ + Header: h.Header, + HashedValue: h.Value, + }) + } + } + + mapping := &proto.ProxyMapping{ + Type: operationToProtoType(operation), + Id: s.ID, + Domain: s.Domain, + Path: pathMappings, + AuthToken: authToken, + Auth: auth, + AccountId: s.AccountID, + PassHostHeader: s.PassHostHeader, + RewriteRedirects: s.RewriteRedirects, + Mode: s.Mode, + ListenPort: int32(s.ListenPort), //nolint:gosec + } + + if r := restrictionsToProto(s.Restrictions); r != nil { + mapping.AccessRestrictions = r + } + + return mapping +} + +// buildPathMappings constructs PathMapping entries from targets. +// For HTTP/HTTPS, each target becomes a path-based route with a full URL. +// For L4/TLS, a single target maps to a host:port address. +func (s *Service) buildPathMappings() []*proto.PathMapping { + pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) + for _, target := range s.Targets { + if !target.Enabled { + continue + } + + if IsL4Protocol(s.Mode) { + pm := &proto.PathMapping{ + Target: net.JoinHostPort(target.Host, strconv.FormatUint(uint64(target.Port), 10)), + } + opts := l4TargetOptionsToProto(target) + if opts != nil { + pm.Options = opts + } + pathMappings = append(pathMappings, pm) + continue + } + + // HTTP/HTTPS: build full URL + targetURL := url.URL{ + Scheme: target.Protocol, + Host: target.Host, + Path: "/", + } + if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { + targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + } + + path := "/" + if target.Path != nil { + path = *target.Path + } + + pm := &proto.PathMapping{ + Path: path, + Target: targetURL.String(), + } + pm.Options = targetOptionsToProto(target.Options) + pathMappings = append(pathMappings, pm) + } + return pathMappings +} + +func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { + switch op { + case Create: + return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED + case Update: + return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED + case Delete: + return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED + default: + panic(fmt.Sprintf("unknown operation type: %v", op)) + } +} + +// isDefaultPort reports whether port is the standard default for the given scheme +// (443 for https, 80 for http). +func isDefaultPort(scheme string, port uint16) bool { + return (scheme == TargetProtoHTTPS && port == 443) || (scheme == TargetProtoHTTP && port == 80) +} + +// PathRewriteMode controls how the request path is rewritten before forwarding. +type PathRewriteMode string + +const ( + PathRewritePreserve PathRewriteMode = "preserve" +) + +func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode { + switch mode { + case PathRewritePreserve: + return proto.PathRewriteMode_PATH_REWRITE_PRESERVE + default: + return proto.PathRewriteMode_PATH_REWRITE_DEFAULT + } +} + +func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { + if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { + return nil + } + apiOpts := &api.ServiceTargetOptions{} + if opts.SkipTLSVerify { + apiOpts.SkipTlsVerify = &opts.SkipTLSVerify + } + if opts.RequestTimeout != 0 { + s := opts.RequestTimeout.String() + apiOpts.RequestTimeout = &s + } + if opts.SessionIdleTimeout != 0 { + s := opts.SessionIdleTimeout.String() + apiOpts.SessionIdleTimeout = &s + } + if opts.PathRewrite != "" { + pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite) + apiOpts.PathRewrite = &pr + } + if len(opts.CustomHeaders) > 0 { + apiOpts.CustomHeaders = &opts.CustomHeaders + } + return apiOpts +} + +func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { + if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 { + return nil + } + popts := &proto.PathTargetOptions{ + SkipTlsVerify: opts.SkipTLSVerify, + PathRewrite: pathRewriteToProto(opts.PathRewrite), + CustomHeaders: opts.CustomHeaders, + } + if opts.RequestTimeout != 0 { + popts.RequestTimeout = durationpb.New(opts.RequestTimeout) + } + return popts +} + +// l4TargetOptionsToProto converts L4-relevant target options to proto. +func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions { + if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 { + return nil + } + opts := &proto.PathTargetOptions{ + ProxyProtocol: target.ProxyProtocol, + } + if target.Options.RequestTimeout > 0 { + opts.RequestTimeout = durationpb.New(target.Options.RequestTimeout) + } + if target.Options.SessionIdleTimeout > 0 { + opts.SessionIdleTimeout = durationpb.New(target.Options.SessionIdleTimeout) + } + return opts +} + +func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) { + var opts TargetOptions + if o.SkipTlsVerify != nil { + opts.SkipTLSVerify = *o.SkipTlsVerify + } + if o.RequestTimeout != nil { + d, err := time.ParseDuration(*o.RequestTimeout) + if err != nil { + return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err) + } + opts.RequestTimeout = d + } + if o.SessionIdleTimeout != nil { + d, err := time.ParseDuration(*o.SessionIdleTimeout) + if err != nil { + return opts, fmt.Errorf("target %d: parse session_idle_timeout %q: %w", idx, *o.SessionIdleTimeout, err) + } + opts.SessionIdleTimeout = d + } + if o.PathRewrite != nil { + opts.PathRewrite = PathRewriteMode(*o.PathRewrite) + } + if o.CustomHeaders != nil { + opts.CustomHeaders = *o.CustomHeaders + } + return opts, nil +} + +func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error { + s.Name = req.Name + s.Domain = req.Domain + s.AccountID = accountID + + if req.Mode != nil { + s.Mode = string(*req.Mode) + } + if req.ListenPort != nil { + s.ListenPort = uint16(*req.ListenPort) //nolint:gosec + } + + targets, err := targetsFromAPI(accountID, req.Targets) + if err != nil { + return err + } + s.Targets = targets + s.Enabled = req.Enabled + + if req.PassHostHeader != nil { + s.PassHostHeader = *req.PassHostHeader + } + if req.RewriteRedirects != nil { + s.RewriteRedirects = *req.RewriteRedirects + } + + if req.Auth != nil { + s.Auth = authFromAPI(req.Auth) + } + + if req.AccessRestrictions != nil { + restrictions, err := restrictionsFromAPI(req.AccessRestrictions) + if err != nil { + return err + } + s.Restrictions = restrictions + } + + return nil +} + +func targetsFromAPI(accountID string, apiTargetsPtr *[]api.ServiceTarget) ([]*Target, error) { + var apiTargets []api.ServiceTarget + if apiTargetsPtr != nil { + apiTargets = *apiTargetsPtr + } + + targets := make([]*Target, 0, len(apiTargets)) + for i, apiTarget := range apiTargets { + target := &Target{ + AccountID: accountID, + Path: apiTarget.Path, + Port: uint16(apiTarget.Port), //nolint:gosec // validated by API layer + Protocol: string(apiTarget.Protocol), + TargetId: apiTarget.TargetId, + TargetType: TargetType(apiTarget.TargetType), + Enabled: apiTarget.Enabled, + } + if apiTarget.Host != nil { + target.Host = *apiTarget.Host + } + if apiTarget.Options != nil { + opts, err := targetOptionsFromAPI(i, apiTarget.Options) + if err != nil { + return nil, err + } + target.Options = opts + if apiTarget.Options.ProxyProtocol != nil { + target.ProxyProtocol = *apiTarget.Options.ProxyProtocol + } + } + targets = append(targets, target) + } + return targets, nil +} + +func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig { + var auth AuthConfig + if reqAuth.PasswordAuth != nil { + auth.PasswordAuth = &PasswordAuthConfig{ + Enabled: reqAuth.PasswordAuth.Enabled, + Password: reqAuth.PasswordAuth.Password, + } + } + if reqAuth.PinAuth != nil { + auth.PinAuth = &PINAuthConfig{ + Enabled: reqAuth.PinAuth.Enabled, + Pin: reqAuth.PinAuth.Pin, + } + } + if reqAuth.BearerAuth != nil { + bearerAuth := &BearerAuthConfig{ + Enabled: reqAuth.BearerAuth.Enabled, + } + if reqAuth.BearerAuth.DistributionGroups != nil { + bearerAuth.DistributionGroups = *reqAuth.BearerAuth.DistributionGroups + } + auth.BearerAuth = bearerAuth + } + if reqAuth.HeaderAuths != nil { + for _, h := range *reqAuth.HeaderAuths { + auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{ + Enabled: h.Enabled, + Header: h.Header, + Value: h.Value, + }) + } + } + return auth +} + +func restrictionsFromAPI(r *api.AccessRestrictions) (AccessRestrictions, error) { + if r == nil { + return AccessRestrictions{}, nil + } + var res AccessRestrictions + if r.AllowedCidrs != nil { + res.AllowedCIDRs = *r.AllowedCidrs + } + if r.BlockedCidrs != nil { + res.BlockedCIDRs = *r.BlockedCidrs + } + if r.AllowedCountries != nil { + res.AllowedCountries = *r.AllowedCountries + } + if r.BlockedCountries != nil { + res.BlockedCountries = *r.BlockedCountries + } + if r.CrowdsecMode != nil { + if !r.CrowdsecMode.Valid() { + return AccessRestrictions{}, fmt.Errorf("invalid crowdsec_mode %q", *r.CrowdsecMode) + } + res.CrowdSecMode = string(*r.CrowdsecMode) + } + return res, nil +} + +func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions { + if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && + len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 && + r.CrowdSecMode == "" { + return nil + } + res := &api.AccessRestrictions{} + if len(r.AllowedCIDRs) > 0 { + res.AllowedCidrs = &r.AllowedCIDRs + } + if len(r.BlockedCIDRs) > 0 { + res.BlockedCidrs = &r.BlockedCIDRs + } + if len(r.AllowedCountries) > 0 { + res.AllowedCountries = &r.AllowedCountries + } + if len(r.BlockedCountries) > 0 { + res.BlockedCountries = &r.BlockedCountries + } + if r.CrowdSecMode != "" { + mode := api.AccessRestrictionsCrowdsecMode(r.CrowdSecMode) + res.CrowdsecMode = &mode + } + return res +} + +func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions { + if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && + len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 && + r.CrowdSecMode == "" { + return nil + } + return &proto.AccessRestrictions{ + AllowedCidrs: r.AllowedCIDRs, + BlockedCidrs: r.BlockedCIDRs, + AllowedCountries: r.AllowedCountries, + BlockedCountries: r.BlockedCountries, + CrowdsecMode: r.CrowdSecMode, + } +} + +func (s *Service) Validate() error { + if s.Name == "" { + return errors.New("service name is required") + } + if len(s.Name) > 255 { + return errors.New("service name exceeds maximum length of 255 characters") + } + + if len(s.Targets) == 0 { + return errors.New("at least one target is required") + } + + if s.Mode == "" { + s.Mode = ModeHTTP + } + + if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil { + return err + } + if err := validateAccessRestrictions(&s.Restrictions); err != nil { + return err + } + + switch s.Mode { + case ModeHTTP: + return s.validateHTTPMode() + case ModeTCP, ModeUDP: + return s.validateTCPUDPMode() + case ModeTLS: + return s.validateTLSMode() + default: + return fmt.Errorf("unsupported mode %q", s.Mode) + } +} + +func (s *Service) validateHTTPMode() error { + if s.Domain == "" { + return errors.New("service domain is required") + } + if s.ListenPort != 0 { + return errors.New("listen_port is not supported for HTTP services") + } + return s.validateHTTPTargets() +} + +func (s *Service) validateTCPUDPMode() error { + if s.Domain == "" { + return errors.New("domain is required for TCP/UDP services (used for cluster derivation)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TCP/UDP services") + } + if len(s.Targets) != 1 { + return errors.New("TCP/UDP services must have exactly one target") + } + if s.Mode == ModeUDP && s.Targets[0].ProxyProtocol { + return errors.New("proxy_protocol is not supported for UDP services") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateTLSMode() error { + if s.Domain == "" { + return errors.New("domain is required for TLS services (used for SNI matching)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TLS services") + } + if s.ListenPort == 0 { + return errors.New("listen_port is required for TLS services") + } + if len(s.Targets) != 1 { + return errors.New("TLS services must have exactly one target") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateHTTPTargets() error { + for i, target := range s.Targets { + switch target.TargetType { + case TargetTypePeer, TargetTypeHost, TargetTypeDomain: + // host field will be ignored + case TargetTypeSubnet: + if target.Host == "" { + return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType) + } + default: + return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType) + } + if target.TargetId == "" { + return fmt.Errorf("target %d has empty target_id", i) + } + if target.ProxyProtocol { + return fmt.Errorf("target %d: proxy_protocol is not supported for HTTP services", i) + } + if err := validateTargetOptions(i, &target.Options); err != nil { + return err + } + } + + return nil +} + +func (s *Service) validateL4Target(target *Target) error { + // L4 services have a single target; per-target disable is meaningless + // (use the service-level Enabled flag instead). Force it on so that + // buildPathMappings always includes the target in the proto. + target.Enabled = true + + if target.Port == 0 { + return errors.New("target port is required for L4 services") + } + if target.TargetId == "" { + return errors.New("target_id is required for L4 services") + } + switch target.TargetType { + case TargetTypePeer, TargetTypeHost, TargetTypeDomain: + // OK + case TargetTypeSubnet: + if target.Host == "" { + return errors.New("target host is required for subnet targets") + } + default: + return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType) + } + if target.Path != nil && *target.Path != "" && *target.Path != "/" { + return errors.New("path is not supported for L4 services") + } + if target.Options.SessionIdleTimeout < 0 { + return errors.New("session_idle_timeout must be positive for L4 services") + } + if target.Options.RequestTimeout < 0 { + return errors.New("request_timeout must be positive for L4 services") + } + if target.Options.SkipTLSVerify { + return errors.New("skip_tls_verify is not supported for L4 services") + } + if target.Options.PathRewrite != "" { + return errors.New("path_rewrite is not supported for L4 services") + } + if len(target.Options.CustomHeaders) > 0 { + return errors.New("custom_headers is not supported for L4 services") + } + return nil +} + +// Service mode constants. +const ( + ModeHTTP = "http" + ModeTCP = "tcp" + ModeUDP = "udp" + ModeTLS = "tls" +) + +// Target protocol constants (URL scheme for backend connections). +const ( + TargetProtoHTTP = "http" + TargetProtoHTTPS = "https" + TargetProtoTCP = "tcp" + TargetProtoUDP = "udp" +) + +// IsL4Protocol returns true if the mode requires port-based routing (TCP, UDP, or TLS). +func IsL4Protocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP || mode == ModeTLS +} + +// IsPortBasedProtocol returns true if the mode relies on dedicated port allocation. +// TLS is excluded because it uses SNI routing and can share ports with other TLS services. +func IsPortBasedProtocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP +} + +const ( + maxCustomHeaders = 16 + maxHeaderKeyLen = 128 + maxHeaderValueLen = 4096 +) + +// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition. +var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`) + +// hopByHopHeaders are headers that must not be set as custom headers +// because they are connection-level and stripped by the proxy. +var hopByHopHeaders = map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Proxy-Connection": {}, + "Te": {}, + "Trailer": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, +} + +// reservedHeaders are set authoritatively by the proxy or control HTTP framing +// and cannot be overridden. +var reservedHeaders = map[string]struct{}{ + "Content-Length": {}, + "Content-Type": {}, + "Cookie": {}, + "Forwarded": {}, + "X-Forwarded-For": {}, + "X-Forwarded-Host": {}, + "X-Forwarded-Port": {}, + "X-Forwarded-Proto": {}, + "X-Real-Ip": {}, +} + +func validateTargetOptions(idx int, opts *TargetOptions) error { + if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve { + return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite) + } + + if opts.RequestTimeout < 0 { + return fmt.Errorf("target %d: request_timeout must be positive", idx) + } + + if opts.SessionIdleTimeout < 0 { + return fmt.Errorf("target %d: session_idle_timeout must be positive", idx) + } + + if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil { + return err + } + + return nil +} + +func validateCustomHeaders(idx int, headers map[string]string) error { + if len(headers) > maxCustomHeaders { + return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders) + } + seen := make(map[string]string, len(headers)) + for key, value := range headers { + if !httpHeaderNameRe.MatchString(key) { + return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key) + } + if len(key) > maxHeaderKeyLen { + return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen) + } + if len(value) > maxHeaderValueLen { + return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen) + } + if containsCRLF(key) || containsCRLF(value) { + return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key) + } + canonical := http.CanonicalHeaderKey(key) + if prev, ok := seen[canonical]; ok { + return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical) + } + seen[canonical] = key + if _, ok := hopByHopHeaders[canonical]; ok { + return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key) + } + if _, ok := reservedHeaders[canonical]; ok { + return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key) + } + if canonical == "Host" { + return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx) + } + } + return nil +} + +func containsCRLF(s string) bool { + return strings.ContainsAny(s, "\r\n") +} + +func validateHeaderAuths(headers []*HeaderAuthConfig) error { + for i, h := range headers { + if h == nil || !h.Enabled { + continue + } + if h.Header == "" { + return fmt.Errorf("header_auths[%d]: header name is required", i) + } + if !httpHeaderNameRe.MatchString(h.Header) { + return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header) + } + canonical := http.CanonicalHeaderKey(h.Header) + if _, ok := hopByHopHeaders[canonical]; ok { + return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, h.Header) + } + if _, ok := reservedHeaders[canonical]; ok { + return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header) + } + if canonical == "Host" { + return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i) + } + if len(h.Value) > maxHeaderValueLen { + return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen) + } + } + return nil +} + +const ( + maxCIDREntries = 200 + maxCountryEntries = 50 +) + +// validateAccessRestrictions validates and normalizes access restriction +// entries. Country codes are uppercased in place. +func validateCrowdSecMode(mode string) error { + switch mode { + case "", "off", "enforce", "observe": + return nil + default: + return fmt.Errorf("crowdsec_mode %q is invalid", mode) + } +} + +func validateAccessRestrictions(r *AccessRestrictions) error { + if err := validateCrowdSecMode(r.CrowdSecMode); err != nil { + return err + } + + if len(r.AllowedCIDRs) > maxCIDREntries { + return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries) + } + if len(r.BlockedCIDRs) > maxCIDREntries { + return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries) + } + if len(r.AllowedCountries) > maxCountryEntries { + return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries) + } + if len(r.BlockedCountries) > maxCountryEntries { + return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries) + } + + if err := validateCIDRList("allowed_cidrs", r.AllowedCIDRs); err != nil { + return err + } + if err := validateCIDRList("blocked_cidrs", r.BlockedCIDRs); err != nil { + return err + } + if err := normalizeCountryList("allowed_countries", r.AllowedCountries); err != nil { + return err + } + return normalizeCountryList("blocked_countries", r.BlockedCountries) +} + +func validateCIDRList(field string, cidrs []string) error { + for i, raw := range cidrs { + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return fmt.Errorf("%s[%d]: %w", field, i, err) + } + if prefix != prefix.Masked() { + return fmt.Errorf("%s[%d]: %q has host bits set, use %s instead", field, i, raw, prefix.Masked()) + } + } + return nil +} + +func normalizeCountryList(field string, codes []string) error { + for i, code := range codes { + if len(code) != 2 { + return fmt.Errorf("%s[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", field, i, code) + } + codes[i] = strings.ToUpper(code) + } + return nil +} + +func (s *Service) EventMeta() map[string]any { + meta := map[string]any{ + "name": s.Name, + "domain": s.Domain, + "proxy_cluster": s.ProxyCluster, + "source": s.Source, + "auth": s.isAuthEnabled(), + "mode": s.Mode, + } + + if s.ListenPort != 0 { + meta["listen_port"] = s.ListenPort + } + + if len(s.Targets) > 0 { + t := s.Targets[0] + if t.ProxyProtocol { + meta["proxy_protocol"] = true + } + if t.Options.RequestTimeout != 0 { + meta["request_timeout"] = t.Options.RequestTimeout.String() + } + if t.Options.SessionIdleTimeout != 0 { + meta["session_idle_timeout"] = t.Options.SessionIdleTimeout.String() + } + } + + return meta +} + +func (s *Service) isAuthEnabled() bool { + if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || + (s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) || + (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) { + return true + } + for _, h := range s.Auth.HeaderAuths { + if h != nil && h.Enabled { + return true + } + } + return false +} + +func (s *Service) Copy() *Service { + targets := make([]*Target, len(s.Targets)) + for i, target := range s.Targets { + targetCopy := *target + if target.Path != nil { + p := *target.Path + targetCopy.Path = &p + } + if len(target.Options.CustomHeaders) > 0 { + targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders)) + for k, v := range target.Options.CustomHeaders { + targetCopy.Options.CustomHeaders[k] = v + } + } + targets[i] = &targetCopy + } + + authCopy := s.Auth + if s.Auth.PasswordAuth != nil { + pa := *s.Auth.PasswordAuth + authCopy.PasswordAuth = &pa + } + if s.Auth.PinAuth != nil { + pa := *s.Auth.PinAuth + authCopy.PinAuth = &pa + } + if s.Auth.BearerAuth != nil { + ba := *s.Auth.BearerAuth + if len(s.Auth.BearerAuth.DistributionGroups) > 0 { + ba.DistributionGroups = make([]string, len(s.Auth.BearerAuth.DistributionGroups)) + copy(ba.DistributionGroups, s.Auth.BearerAuth.DistributionGroups) + } + authCopy.BearerAuth = &ba + } + if len(s.Auth.HeaderAuths) > 0 { + authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths)) + for i, h := range s.Auth.HeaderAuths { + if h == nil { + continue + } + hCopy := *h + authCopy.HeaderAuths[i] = &hCopy + } + } + + return &Service{ + ID: s.ID, + AccountID: s.AccountID, + Name: s.Name, + Domain: s.Domain, + ProxyCluster: s.ProxyCluster, + Targets: targets, + Enabled: s.Enabled, + Terminated: s.Terminated, + PassHostHeader: s.PassHostHeader, + RewriteRedirects: s.RewriteRedirects, + Auth: authCopy, + Restrictions: s.Restrictions.Copy(), + Meta: s.Meta, + SessionPrivateKey: s.SessionPrivateKey, + SessionPublicKey: s.SessionPublicKey, + Source: s.Source, + SourcePeer: s.SourcePeer, + Mode: s.Mode, + ListenPort: s.ListenPort, + PortAutoAssigned: s.PortAutoAssigned, + } +} + +func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + if s.SessionPrivateKey != "" { + var err error + s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey) + if err != nil { + return err + } + } + + return nil +} + +func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + if s.SessionPrivateKey != "" { + var err error + s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey) + if err != nil { + return err + } + } + + return nil +} + +var pinRegexp = regexp.MustCompile(`^\d{6}$`) + +const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789" + +var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`) + +// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service. +type ExposeServiceRequest struct { + NamePrefix string + Port uint16 + Mode string + // TargetProtocol is the protocol used to connect to the peer backend. + // For HTTP mode: "http" (default) or "https". For L4 modes: "tcp" or "udp". + TargetProtocol string + Domain string + Pin string + Password string + UserGroups []string + ListenPort uint16 +} + +// Validate checks all fields of the expose request. +func (r *ExposeServiceRequest) Validate() error { + if r == nil { + return errors.New("request cannot be nil") + } + + if r.Port == 0 { + return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port) + } + + switch r.Mode { + case ModeHTTP, ModeTCP, ModeUDP, ModeTLS: + default: + return fmt.Errorf("unsupported mode %q", r.Mode) + } + + if IsL4Protocol(r.Mode) { + if r.Pin != "" || r.Password != "" || len(r.UserGroups) > 0 { + return fmt.Errorf("authentication is not supported for %s mode", r.Mode) + } + } + + if r.Pin != "" && !pinRegexp.MatchString(r.Pin) { + return errors.New("invalid pin: must be exactly 6 digits") + } + + for _, g := range r.UserGroups { + if g == "" { + return errors.New("user group name cannot be empty") + } + } + + if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) { + return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix) + } + + return nil +} + +// ToService builds a Service from the expose request. +func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service { + svc := &Service{ + AccountID: accountID, + Name: serviceName, + Mode: r.Mode, + Enabled: true, + } + + // If domain is empty, CreateServiceFromPeer generates a unique subdomain. + // When explicitly provided, the service name is prepended as a subdomain. + if r.Domain != "" { + svc.Domain = serviceName + "." + r.Domain + } + + if IsL4Protocol(r.Mode) { + svc.ListenPort = r.Port + if r.ListenPort > 0 { + svc.ListenPort = r.ListenPort + } + } + + var targetProto string + switch { + case !IsL4Protocol(r.Mode): + targetProto = TargetProtoHTTP + if r.TargetProtocol != "" { + targetProto = r.TargetProtocol + } + case r.Mode == ModeUDP: + targetProto = TargetProtoUDP + default: + targetProto = TargetProtoTCP + } + svc.Targets = []*Target{ + { + AccountID: accountID, + Port: r.Port, + Protocol: targetProto, + TargetId: peerID, + TargetType: TargetTypePeer, + Enabled: true, + }, + } + + if r.Pin != "" { + svc.Auth.PinAuth = &PINAuthConfig{ + Enabled: true, + Pin: r.Pin, + } + } + + if r.Password != "" { + svc.Auth.PasswordAuth = &PasswordAuthConfig{ + Enabled: true, + Password: r.Password, + } + } + + if len(r.UserGroups) > 0 { + svc.Auth.BearerAuth = &BearerAuthConfig{ + Enabled: true, + DistributionGroups: r.UserGroups, + } + } + + return svc +} + +// ExposeServiceResponse contains the result of a successful peer expose creation. +type ExposeServiceResponse struct { + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool +} + +// GenerateExposeName generates a random service name for peer-exposed services. +// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens). +func GenerateExposeName(prefix string) (string, error) { + if prefix != "" && !validNamePrefix.MatchString(prefix) { + return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix) + } + + suffixLen := 12 + if prefix != "" { + suffixLen = 4 + } + + suffix, err := randomAlphanumeric(suffixLen) + if err != nil { + return "", fmt.Errorf("generate random name: %w", err) + } + + if prefix == "" { + return suffix, nil + } + return prefix + "-" + suffix, nil +} + +func randomAlphanumeric(n int) (string, error) { + result := make([]byte, n) + charsetLen := big.NewInt(int64(len(alphanumCharset))) + for i := range result { + idx, err := rand.Int(rand.Reader, charsetLen) + if err != nil { + return "", err + } + result[i] = alphanumCharset[idx.Int64()] + } + return string(result), nil +} diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go new file mode 100644 index 000000000..ff54cb79f --- /dev/null +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -0,0 +1,1041 @@ +package service + +import ( + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/shared/hash/argon2id" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func validProxy() *Service { + return &Service{ + Name: "test", + Domain: "example.com", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true}, + }, + } +} + +func TestValidate_Valid(t *testing.T) { + require.NoError(t, validProxy().Validate()) +} + +func TestValidate_EmptyName(t *testing.T) { + rp := validProxy() + rp.Name = "" + assert.ErrorContains(t, rp.Validate(), "name is required") +} + +func TestValidate_EmptyDomain(t *testing.T) { + rp := validProxy() + rp.Domain = "" + assert.ErrorContains(t, rp.Validate(), "domain is required") +} + +func TestValidate_NoTargets(t *testing.T) { + rp := validProxy() + rp.Targets = nil + assert.ErrorContains(t, rp.Validate(), "at least one target is required") +} + +func TestValidate_EmptyTargetId(t *testing.T) { + rp := validProxy() + rp.Targets[0].TargetId = "" + assert.ErrorContains(t, rp.Validate(), "empty target_id") +} + +func TestValidate_InvalidTargetType(t *testing.T) { + rp := validProxy() + rp.Targets[0].TargetType = "invalid" + assert.ErrorContains(t, rp.Validate(), "invalid target_type") +} + +func TestValidate_ResourceTarget(t *testing.T) { + rp := validProxy() + rp.Targets = append(rp.Targets, &Target{ + TargetId: "resource-1", + TargetType: TargetTypeHost, + Host: "example.org", + Port: 443, + Protocol: "https", + Enabled: true, + }) + require.NoError(t, rp.Validate()) +} + +func TestValidate_MultipleTargetsOneInvalid(t *testing.T) { + rp := validProxy() + rp.Targets = append(rp.Targets, &Target{ + TargetId: "", + TargetType: TargetTypePeer, + Host: "10.0.0.2", + Port: 80, + Protocol: "http", + Enabled: true, + }) + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "target 1") + assert.Contains(t, err.Error(), "empty target_id") +} + +func TestValidateTargetOptions_PathRewrite(t *testing.T) { + tests := []struct { + name string + mode PathRewriteMode + wantErr string + }{ + {"empty is default", "", ""}, + {"preserve is valid", PathRewritePreserve, ""}, + {"unknown rejected", "regex", "unknown path_rewrite mode"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.PathRewrite = tt.mode + err := rp.Validate() + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tt.wantErr) + } + }) + } +} + +func TestValidateTargetOptions_RequestTimeout(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantErr string + }{ + {"valid 30s", 30 * time.Second, ""}, + {"valid 2m", 2 * time.Minute, ""}, + {"valid 10m", 10 * time.Minute, ""}, + {"zero is fine", 0, ""}, + {"negative", -1 * time.Second, "must be positive"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.RequestTimeout = tt.timeout + err := rp.Validate() + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tt.wantErr) + } + }) + } +} + +func TestValidateTargetOptions_CustomHeaders(t *testing.T) { + t.Run("valid headers", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{ + "X-Custom": "value", + "X-Trace": "abc123", + } + assert.NoError(t, rp.Validate()) + }) + + t.Run("CRLF in key", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"} + assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name") + }) + + t.Run("CRLF in value", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"} + assert.ErrorContains(t, rp.Validate(), "invalid characters") + }) + + t.Run("hop-by-hop header rejected", func(t *testing.T) { + for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"} + assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h) + } + }) + + t.Run("reserved header rejected", func(t *testing.T) { + for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"} + assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h) + } + }) + + t.Run("Host header rejected", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"} + assert.ErrorContains(t, rp.Validate(), "pass_host_header") + }) + + t.Run("too many headers", func(t *testing.T) { + rp := validProxy() + headers := make(map[string]string, 17) + for i := range 17 { + headers[fmt.Sprintf("X-H%d", i)] = "v" + } + rp.Targets[0].Options.CustomHeaders = headers + assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16") + }) + + t.Run("key too long", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"} + assert.ErrorContains(t, rp.Validate(), "key") + assert.ErrorContains(t, rp.Validate(), "exceeds maximum length") + }) + + t.Run("value too long", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)} + assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length") + }) + + t.Run("duplicate canonical keys rejected", func(t *testing.T) { + rp := validProxy() + rp.Targets[0].Options.CustomHeaders = map[string]string{ + "x-custom": "a", + "X-Custom": "b", + } + assert.ErrorContains(t, rp.Validate(), "collide") + }) +} + +func TestToProtoMapping_TargetOptions(t *testing.T) { + rp := &Service{ + ID: "svc-1", + AccountID: "acc-1", + Domain: "example.com", + Targets: []*Target{ + { + TargetId: "peer-1", + TargetType: TargetTypePeer, + Host: "10.0.0.1", + Port: 8080, + Protocol: "http", + Enabled: true, + Options: TargetOptions{ + SkipTLSVerify: true, + RequestTimeout: 30 * time.Second, + PathRewrite: PathRewritePreserve, + CustomHeaders: map[string]string{"X-Custom": "val"}, + }, + }, + }, + } + pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{}) + require.Len(t, pm.Path, 1) + + opts := pm.Path[0].Options + require.NotNil(t, opts, "options should be populated") + assert.True(t, opts.SkipTlsVerify) + assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite) + assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders) + require.NotNil(t, opts.RequestTimeout) + assert.Equal(t, int64(30), opts.RequestTimeout.Seconds) +} + +func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) { + rp := &Service{ + ID: "svc-1", + AccountID: "acc-1", + Domain: "example.com", + Targets: []*Target{ + { + TargetId: "peer-1", + TargetType: TargetTypePeer, + Host: "10.0.0.1", + Port: 8080, + Protocol: "http", + Enabled: true, + }, + }, + } + pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{}) + require.Len(t, pm.Path, 1) + assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults") +} + +func TestIsDefaultPort(t *testing.T) { + tests := []struct { + scheme string + port uint16 + want bool + }{ + {"http", 80, true}, + {"https", 443, true}, + {"http", 443, false}, + {"https", 80, false}, + {"http", 8080, false}, + {"https", 8443, false}, + {"http", 0, false}, + {"https", 0, false}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) { + assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port)) + }) + } +} + +func TestToProtoMapping_PortInTargetURL(t *testing.T) { + oidcConfig := proxy.OIDCValidationConfig{} + + tests := []struct { + name string + protocol string + host string + port uint16 + wantTarget string + }{ + { + name: "http with default port 80 omits port", + protocol: "http", + host: "10.0.0.1", + port: 80, + wantTarget: "http://10.0.0.1/", + }, + { + name: "https with default port 443 omits port", + protocol: "https", + host: "10.0.0.1", + port: 443, + wantTarget: "https://10.0.0.1/", + }, + { + name: "port 0 omits port", + protocol: "http", + host: "10.0.0.1", + port: 0, + wantTarget: "http://10.0.0.1/", + }, + { + name: "non-default port is included", + protocol: "http", + host: "10.0.0.1", + port: 8080, + wantTarget: "http://10.0.0.1:8080/", + }, + { + name: "https with non-default port is included", + protocol: "https", + host: "10.0.0.1", + port: 8443, + wantTarget: "https://10.0.0.1:8443/", + }, + { + name: "http port 443 is included", + protocol: "http", + host: "10.0.0.1", + port: 443, + wantTarget: "http://10.0.0.1:443/", + }, + { + name: "https port 80 is included", + protocol: "https", + host: "10.0.0.1", + port: 80, + wantTarget: "https://10.0.0.1:80/", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := &Service{ + ID: "test-id", + AccountID: "acc-1", + Domain: "example.com", + Targets: []*Target{ + { + TargetId: "peer-1", + TargetType: TargetTypePeer, + Host: tt.host, + Port: tt.port, + Protocol: tt.protocol, + Enabled: true, + }, + }, + } + pm := rp.ToProtoMapping(Create, "token", oidcConfig) + require.Len(t, pm.Path, 1, "should have one path mapping") + assert.Equal(t, tt.wantTarget, pm.Path[0].Target) + }) + } +} + +func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) { + rp := &Service{ + ID: "test-id", + AccountID: "acc-1", + Domain: "example.com", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false}, + {TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true}, + }, + } + pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{}) + require.Len(t, pm.Path, 1) + assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target) +} + +func TestToProtoMapping_OperationTypes(t *testing.T) { + rp := validProxy() + tests := []struct { + op Operation + want proto.ProxyMappingUpdateType + }{ + {Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED}, + {Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED}, + {Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED}, + } + for _, tt := range tests { + t.Run(string(tt.op), func(t *testing.T) { + pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{}) + assert.Equal(t, tt.want, pm.Type) + }) + } +} + +func TestAuthConfig_HashSecrets(t *testing.T) { + tests := []struct { + name string + config *AuthConfig + wantErr bool + validate func(*testing.T, *AuthConfig) + }{ + { + name: "hash password successfully", + config: &AuthConfig{ + PasswordAuth: &PasswordAuthConfig{ + Enabled: true, + Password: "testPassword123", + }, + }, + wantErr: false, + validate: func(t *testing.T, config *AuthConfig) { + if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") { + t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password) + } + // Verify the hash can be verified + if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil { + t.Errorf("Hash verification failed: %v", err) + } + }, + }, + { + name: "hash PIN successfully", + config: &AuthConfig{ + PinAuth: &PINAuthConfig{ + Enabled: true, + Pin: "123456", + }, + }, + wantErr: false, + validate: func(t *testing.T, config *AuthConfig) { + if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") { + t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin) + } + // Verify the hash can be verified + if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil { + t.Errorf("Hash verification failed: %v", err) + } + }, + }, + { + name: "hash both password and PIN", + config: &AuthConfig{ + PasswordAuth: &PasswordAuthConfig{ + Enabled: true, + Password: "password", + }, + PinAuth: &PINAuthConfig{ + Enabled: true, + Pin: "9999", + }, + }, + wantErr: false, + validate: func(t *testing.T, config *AuthConfig) { + if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") { + t.Errorf("Password not hashed with argon2id") + } + if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") { + t.Errorf("PIN not hashed with argon2id") + } + if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil { + t.Errorf("Password hash verification failed: %v", err) + } + if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil { + t.Errorf("PIN hash verification failed: %v", err) + } + }, + }, + { + name: "skip disabled password auth", + config: &AuthConfig{ + PasswordAuth: &PasswordAuthConfig{ + Enabled: false, + Password: "password", + }, + }, + wantErr: false, + validate: func(t *testing.T, config *AuthConfig) { + if config.PasswordAuth.Password != "password" { + t.Errorf("Disabled password auth should not be hashed") + } + }, + }, + { + name: "skip empty password", + config: &AuthConfig{ + PasswordAuth: &PasswordAuthConfig{ + Enabled: true, + Password: "", + }, + }, + wantErr: false, + validate: func(t *testing.T, config *AuthConfig) { + if config.PasswordAuth.Password != "" { + t.Errorf("Empty password should remain empty") + } + }, + }, + { + name: "skip nil password auth", + config: &AuthConfig{ + PasswordAuth: nil, + PinAuth: &PINAuthConfig{ + Enabled: true, + Pin: "1234", + }, + }, + wantErr: false, + validate: func(t *testing.T, config *AuthConfig) { + if config.PasswordAuth != nil { + t.Errorf("PasswordAuth should remain nil") + } + if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") { + t.Errorf("PIN should still be hashed") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.HashSecrets() + if (err != nil) != tt.wantErr { + t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.validate != nil { + tt.validate(t, tt.config) + } + }) + } +} + +func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) { + config := &AuthConfig{ + PasswordAuth: &PasswordAuthConfig{ + Enabled: true, + Password: "correctPassword", + }, + } + + if err := config.HashSecrets(); err != nil { + t.Fatalf("HashSecrets() error = %v", err) + } + + // Verify with wrong password should fail + err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password) + if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) { + t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err) + } +} + +func TestAuthConfig_ClearSecrets(t *testing.T) { + config := &AuthConfig{ + PasswordAuth: &PasswordAuthConfig{ + Enabled: true, + Password: "hashedPassword", + }, + PinAuth: &PINAuthConfig{ + Enabled: true, + Pin: "hashedPin", + }, + } + + config.ClearSecrets() + + if config.PasswordAuth.Password != "" { + t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password) + } + if config.PinAuth.Pin != "" { + t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin) + } +} + +func TestGenerateExposeName(t *testing.T) { + t.Run("no prefix generates 12-char name", func(t *testing.T) { + name, err := GenerateExposeName("") + require.NoError(t, err) + assert.Len(t, name, 12) + assert.Regexp(t, `^[a-z0-9]+$`, name) + }) + + t.Run("with prefix generates prefix-XXXX", func(t *testing.T) { + name, err := GenerateExposeName("myapp") + require.NoError(t, err) + assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix") + suffix := strings.TrimPrefix(name, "myapp-") + assert.Len(t, suffix, 4, "suffix should be 4 chars") + assert.Regexp(t, `^[a-z0-9]+$`, suffix) + }) + + t.Run("unique names", func(t *testing.T) { + names := make(map[string]bool) + for i := 0; i < 50; i++ { + name, err := GenerateExposeName("") + require.NoError(t, err) + names[name] = true + } + assert.Greater(t, len(names), 45, "should generate mostly unique names") + }) + + t.Run("valid prefixes", func(t *testing.T) { + validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"} + for _, prefix := range validPrefixes { + name, err := GenerateExposeName(prefix) + assert.NoError(t, err, "prefix %q should be valid", prefix) + assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix) + } + }) + + t.Run("invalid prefixes", func(t *testing.T) { + invalidPrefixes := []string{ + "-starts-with-dash", + "ends-with-dash-", + "has.dots", + "HAS-UPPER", + "has spaces", + "has/slash", + "a--", + } + for _, prefix := range invalidPrefixes { + _, err := GenerateExposeName(prefix) + assert.Error(t, err, "prefix %q should be invalid", prefix) + assert.Contains(t, err.Error(), "invalid name prefix") + } + }) +} + +func TestExposeServiceRequest_ToService(t *testing.T) { + t.Run("basic HTTP service", func(t *testing.T) { + req := &ExposeServiceRequest{ + Port: 8080, + Mode: "http", + } + + service := req.ToService("account-1", "peer-1", "mysvc") + + assert.Equal(t, "account-1", service.AccountID) + assert.Equal(t, "mysvc", service.Name) + assert.True(t, service.Enabled) + assert.Empty(t, service.Domain, "domain should be empty when not specified") + require.Len(t, service.Targets, 1) + + target := service.Targets[0] + assert.Equal(t, uint16(8080), target.Port) + assert.Equal(t, "http", target.Protocol) + assert.Equal(t, "peer-1", target.TargetId) + assert.Equal(t, TargetTypePeer, target.TargetType) + assert.True(t, target.Enabled) + assert.Equal(t, "account-1", target.AccountID) + }) + + t.Run("with custom domain", func(t *testing.T) { + req := &ExposeServiceRequest{ + Port: 3000, + Domain: "example.com", + } + + service := req.ToService("acc", "peer", "web") + assert.Equal(t, "web.example.com", service.Domain) + }) + + t.Run("with PIN auth", func(t *testing.T) { + req := &ExposeServiceRequest{ + Port: 80, + Pin: "1234", + } + + service := req.ToService("acc", "peer", "svc") + require.NotNil(t, service.Auth.PinAuth) + assert.True(t, service.Auth.PinAuth.Enabled) + assert.Equal(t, "1234", service.Auth.PinAuth.Pin) + assert.Nil(t, service.Auth.PasswordAuth) + assert.Nil(t, service.Auth.BearerAuth) + }) + + t.Run("with password auth", func(t *testing.T) { + req := &ExposeServiceRequest{ + Port: 80, + Password: "secret", + } + + service := req.ToService("acc", "peer", "svc") + require.NotNil(t, service.Auth.PasswordAuth) + assert.True(t, service.Auth.PasswordAuth.Enabled) + assert.Equal(t, "secret", service.Auth.PasswordAuth.Password) + }) + + t.Run("with user groups (bearer auth)", func(t *testing.T) { + req := &ExposeServiceRequest{ + Port: 80, + UserGroups: []string{"admins", "devs"}, + } + + service := req.ToService("acc", "peer", "svc") + require.NotNil(t, service.Auth.BearerAuth) + assert.True(t, service.Auth.BearerAuth.Enabled) + assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups) + }) + + t.Run("with all auth types", func(t *testing.T) { + req := &ExposeServiceRequest{ + Port: 443, + Domain: "myco.com", + Pin: "9999", + Password: "pass", + UserGroups: []string{"ops"}, + } + + service := req.ToService("acc", "peer", "full") + assert.Equal(t, "full.myco.com", service.Domain) + require.NotNil(t, service.Auth.PinAuth) + require.NotNil(t, service.Auth.PasswordAuth) + require.NotNil(t, service.Auth.BearerAuth) + }) +} + +func TestValidate_TLSOnly(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TLSMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 0, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "listen_port is required") +} + +func TestValidate_TLSMissingDomain(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "domain is required") +} + +func TestValidate_TCPValid(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TCPMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)") +} + +func TestValidate_L4MultipleTargets(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + {TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "exactly one target") +} + +func TestValidate_L4TargetMissingPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "port is required") +} + +func TestValidate_TLSInvalidTargetType(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.Error(t, rp.Validate()) +} + +func TestValidate_TLSSubnetValid(t *testing.T) { + rp := &Service{ + Name: "tls-subnet", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_L4DomainTargetValid(t *testing.T) { + modes := []struct { + mode string + port uint16 + proto string + }{ + {"tcp", 5432, "tcp"}, + {"tls", 443, "tcp"}, + {"udp", 5432, "udp"}, + } + for _, m := range modes { + t.Run(m.mode, func(t *testing.T) { + rp := &Service{ + Name: m.mode + "-domain", + Mode: m.mode, + Domain: "cluster.test", + ListenPort: m.port, + Targets: []*Target{ + {TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) + }) + } +} + +func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { + rp := validProxy() + rp.Targets[0].ProxyProtocol = true + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP") +} + +func TestValidate_UDPProxyProtocolRejected(t *testing.T) { + rp := &Service{ + Name: "udp-svc", + Mode: "udp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP") +} + +func TestValidate_TCPProxyProtocolAllowed(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) { + tests := []struct { + name string + req ExposeServiceRequest + }{ + { + name: "tcp with pin", + req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"}, + }, + { + name: "udp with password", + req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"}, + }, + { + name: "tls with user groups", + req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.req.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") + }) + } +} + +func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) { + req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"} + require.NoError(t, req.Validate()) +} + +func TestValidate_HeaderAuths(t *testing.T) { + t.Run("single valid header", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "X-API-Key", Value: "secret"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("multiple headers same canonical name allowed", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Authorization", Value: "Bearer token-1"}, + {Enabled: true, Header: "Authorization", Value: "Bearer token-2"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("multiple headers different case same canonical allowed", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "x-api-key", Value: "key-1"}, + {Enabled: true, Header: "X-Api-Key", Value: "key-2"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("multiple different headers allowed", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Authorization", Value: "Bearer tok"}, + {Enabled: true, Header: "X-API-Key", Value: "key"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("empty header name rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "", Value: "val"}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "header name is required") + }) + + t.Run("hop-by-hop header rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Connection", Value: "val"}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "hop-by-hop") + }) + + t.Run("host header rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Host", Value: "val"}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "Host header cannot be used") + }) + + t.Run("disabled entries skipped", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: false, Header: "", Value: ""}, + {Enabled: true, Header: "X-Key", Value: "val"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("value too long rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "X-Key", Value: strings.Repeat("a", maxHeaderValueLen+1)}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum length") + }) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index e897a09f5..24dfb641b 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -18,13 +18,16 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/store" @@ -57,6 +60,18 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { }) } +// CacheStore returns a shared cache store backed by Redis or in-memory depending on the environment. +// All consumers should reuse this store to avoid creating multiple Redis connections. +func (s *BaseServer) CacheStore() cachestore.StoreInterface { + return Create(s, func() cachestore.StoreInterface { + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultStoreMaxTimeout, nbcache.DefaultStoreCleanupInterval, nbcache.DefaultStoreMaxConn) + if err != nil { + log.Fatalf("failed to create shared cache store: %v", err) + } + return cs + }) +} + func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) @@ -94,7 +109,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -134,7 +149,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { if s.Config.HttpConfig.LetsEncryptDomain != "" { certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { - log.Fatalf("failed to create certificate manager: %v", err) + log.Fatalf("failed to create certificate service: %v", err) } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) @@ -152,6 +167,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server { if err != nil { log.Fatalf("failed to create management server: %v", err) } + serviceMgr := s.ServiceManager() + srv.SetReverseProxyManager(serviceMgr) + if serviceMgr != nil { + serviceMgr.StartExposeReaper(context.Background()) + } mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer()) @@ -163,9 +183,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) s.AfterInit(func(s *BaseServer) { - proxyService.SetProxyManager(s.ReverseProxyManager()) + proxyService.SetServiceManager(s.ServiceManager()) + proxyService.SetProxyController(s.ServiceProxyController()) }) return proxyService }) @@ -188,12 +209,18 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore { - tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), s.CacheStore()) log.Info("One-time token store initialized for proxy authentication") return tokenStore }) } +func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { + return Create(s, func() *nbgrpc.PKCEVerifierStore { + return nbgrpc.NewPKCEVerifierStore(context.Background(), s.CacheStore()) + }) +} + func (s *BaseServer) AccessLogsManager() accesslogs.Manager { return Create(s, func() accesslogs.Manager { accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) diff --git a/management/internals/server/config/config.go b/management/internals/server/config/config.go index 0ba393263..fb9c842b7 100644 --- a/management/internals/server/config/config.go +++ b/management/internals/server/config/config.go @@ -203,7 +203,7 @@ type ReverseProxy struct { // AccessLogRetentionDays specifies the number of days to retain access logs. // Logs older than this duration will be automatically deleted during cleanup. - // A value of 0 or negative means logs are kept indefinitely (no cleanup). + // A value of 0 will default to 7 days. Negative means logs are kept indefinitely (no cleanup). AccessLogRetentionDays int // AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine. diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 4ea86900a..9a8e45d33 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -18,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { @@ -38,7 +41,8 @@ func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValida context.Background(), s.PeersManager(), s.SettingsManager(), - s.EventStore()) + s.EventStore(), + s.CacheStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } @@ -69,6 +73,7 @@ func (s *BaseServer) AuthManager() auth.Manager { signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled issuer := s.Config.HttpConfig.AuthIssuer userIDClaim := s.Config.HttpConfig.AuthUserIDClaim + var keyFetcher nbjwt.KeyFetcher // Use embedded IdP configuration if available if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil { @@ -76,8 +81,11 @@ func (s *BaseServer) AuthManager() auth.Manager { if len(audiences) > 0 { audience = audiences[0] // Use the first client ID as the primary audience } - // Use localhost keys location for internal validation (management has embedded Dex) - keysLocation = oauthProvider.GetLocalKeysLocation() + keyFetcher = oauthProvider.GetKeyFetcher() + // Fall back to default keys location if direct key fetching is not available + if keyFetcher == nil { + keysLocation = oauthProvider.GetLocalKeysLocation() + } signingKeyRefreshEnabled = true issuer = oauthProvider.GetIssuer() userIDClaim = oauthProvider.GetUserIDClaim() @@ -90,7 +98,8 @@ func (s *BaseServer) AuthManager() auth.Manager { keysLocation, userIDClaim, audiences, - signingKeyRefreshEnabled) + signingKeyRefreshEnabled, + keyFetcher) }) } @@ -106,6 +115,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller { }) } +func (s *BaseServer) ServiceProxyController() proxy.Controller { + return Create(s, func() proxy.Controller { + controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter()) + if err != nil { + log.Fatalf("failed to create service proxy controller: %v", err) + } + return controller + }) +} + func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { return Create(s, func() *server.AccountRequestBuffer { return server.NewAccountRequestBuffer(context.Background(), s.Store()) diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 58125c0a3..9b2ec2989 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -7,10 +7,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" - nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" "github.com/netbirdio/netbird/management/internals/modules/zones/records" @@ -97,13 +100,13 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy, s.CacheStore()) if err != nil { - log.Fatalf("failed to create account manager: %v", err) + log.Fatalf("failed to create account service: %v", err) } s.AfterInit(func(s *BaseServer) { - accountManager.SetServiceManager(s.ReverseProxyManager()) + accountManager.SetServiceManager(s.ServiceManager()) }) return accountManager @@ -114,28 +117,30 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error - // Use embedded IdP manager if embedded Dex is configured and enabled. + + // Use embedded IdP service if embedded Dex is configured and enabled. // Legacy IdpManager won't be used anymore even if configured. - if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled + if embeddedEnabled { idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) if err != nil { - log.Fatalf("failed to create embedded IDP manager: %v", err) + log.Fatalf("failed to create embedded IDP service: %v", err) } return idpManager } - // Fall back to external IdP manager + // Fall back to external IdP service if s.Config.IdpManagerConfig != nil { idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { - log.Fatalf("failed to create IDP manager: %v", err) + log.Fatalf("failed to create IDP service: %v", err) } } return idpManager }) } -// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil +// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled { return nil @@ -162,7 +167,7 @@ func (s *BaseServer) GroupsManager() groups.Manager { func (s *BaseServer) ResourcesManager() resources.Manager { return Create(s, func() resources.Manager { - return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager()) + return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager()) }) } @@ -190,15 +195,25 @@ func (s *BaseServer) RecordsManager() records.Manager { }) } -func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { - return Create(s, func() reverseproxy.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) +func (s *BaseServer) ServiceManager() service.Manager { + return Create(s, func() service.Manager { + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) + }) +} + +func (s *BaseServer) ProxyManager() proxy.Manager { + return Create(s, func() proxy.Manager { + manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter()) + if err != nil { + log.Fatalf("failed to create proxy manager: %v", err) + } + return manager }) } func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { - m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager()) + m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) return &m }) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 3f7f9c4c0..9b8716da1 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -28,9 +28,13 @@ import ( "github.com/netbirdio/netbird/version" ) -// ManagementLegacyPort is the port that was used before by the Management gRPC server. -// It is used for backward compatibility now. -const ManagementLegacyPort = 33073 +const ( + // ManagementLegacyPort is the port that was used before by the Management gRPC server. + // It is used for backward compatibility now. + ManagementLegacyPort = 33073 + // DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs. + DefaultSelfHostedDomain = "netbird.selfhosted" +) type Server interface { Start(ctx context.Context) error @@ -58,6 +62,7 @@ type BaseServer struct { mgmtMetricsPort int mgmtPort int disableLegacyManagementPort bool + autoResolveDomains bool proxyAuthClose func() @@ -81,6 +86,7 @@ type Config struct { DisableMetrics bool DisableGeoliteUpdate bool UserDeleteFromIDPEnabled bool + AutoResolveDomains bool } // NewServer initializes and configures a new Server instance @@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer { mgmtPort: cfg.MgmtPort, disableLegacyManagementPort: cfg.DisableLegacyManagementPort, mgmtMetricsPort: cfg.MgmtMetricsPort, + autoResolveDomains: cfg.AutoResolveDomains, } } @@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error { s.cancel = cancel s.errCh = make(chan error, 4) + if s.autoResolveDomains { + s.resolveDomains(srvCtx) + } + s.PeersManager() s.GeoLocationManager() @@ -157,7 +168,7 @@ func (s *BaseServer) Start(ctx context.Context) error { // Eagerly create the gRPC server so that all AfterInit hooks are registered // before we iterate them. Lazy creation after the loop would miss hooks - // registered during GRPCServer() construction (e.g., SetProxyManager). + // registered during GRPCServer() construction (e.g., SetServiceManager). s.GRPCServer() for _, fn := range s.afterInit { @@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error { _ = s.certManager.Listener().Close() } s.GRPCServer().Stop() - s.ReverseProxyGRPCServer().Close() if s.proxyAuthClose != nil { s.proxyAuthClose() s.proxyAuthClose = nil @@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene }() } +// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state. +// Fresh installs use the default self-hosted domain, while existing installs reuse the +// persisted account domain to keep addressing stable across config changes. +func (s *BaseServer) resolveDomains(ctx context.Context) { + st := s.Store() + + setDefault := func(logMsg string, args ...any) { + if logMsg != "" { + log.WithContext(ctx).Warnf(logMsg, args...) + } + s.dnsDomain = DefaultSelfHostedDomain + s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain + } + + accountsCount, err := st.GetAccountsCounter(ctx) + if err != nil { + setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain) + return + } + + if accountsCount == 0 { + s.dnsDomain = DefaultSelfHostedDomain + s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain + log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain) + return + } + + accountID, err := st.GetAnyAccountID(ctx) + if err != nil { + setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain) + return + } + + if accountID == "" { + setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain) + return + } + + domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) + if err != nil { + setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain) + return + } + + if domain == "" { + setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain) + return + } + + s.dnsDomain = domain + s.mgmtSingleAccModeDomain = domain + log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain) +} + func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { diff --git a/management/internals/server/server_resolve_domains_test.go b/management/internals/server/server_resolve_domains_test.go new file mode 100644 index 000000000..db1d7e8ca --- /dev/null +++ b/management/internals/server/server_resolve_domains_test.go @@ -0,0 +1,63 @@ +package server + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/store" +) + +func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain) + require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain) +} + +func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil) + mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil) + mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, "vpn.mycompany.com", srv.dnsDomain) + require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain) +} + +func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed")) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain) + require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain) +} diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index c74fa2660..ef417d3cf 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -107,7 +107,8 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, AutoUpdate: &proto.AutoUpdateSettings{ - Version: settings.AutoUpdateVersion, + Version: settings.AutoUpdateVersion, + AlwaysUpdate: settings.AutoUpdateAlways, }, } } diff --git a/management/internals/shared/grpc/expose_service.go b/management/internals/shared/grpc/expose_service.go new file mode 100644 index 000000000..1b87f7ede --- /dev/null +++ b/management/internals/shared/grpc/expose_service.go @@ -0,0 +1,251 @@ +package grpc + +import ( + "context" + "fmt" + + pb "github.com/golang/protobuf/proto" // nolint + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/encryption" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbContext "github.com/netbirdio/netbird/management/server/context" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" + internalStatus "github.com/netbirdio/netbird/shared/management/status" +) + +// CreateExpose handles a peer request to create a new expose service. +func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + exposeReq := &proto.ExposeServiceRequest{} + peerKey, err := s.parseRequest(ctx, req, exposeReq) + if err != nil { + return nil, err + } + + accountID, peer, err := s.authenticateExposePeer(ctx, peerKey) + if err != nil { + return nil, err + } + + // nolint:staticcheck + ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + + reverseProxyMgr := s.getReverseProxyManager() + if reverseProxyMgr == nil { + return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") + } + + if exposeReq.Port > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "port out of range: %d", exposeReq.Port) + } + if exposeReq.ListenPort > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "listen_port out of range: %d", exposeReq.ListenPort) + } + + mode, err := exposeProtocolToString(exposeReq.Protocol) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + + created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{ + NamePrefix: exposeReq.NamePrefix, + Port: uint16(exposeReq.Port), //nolint:gosec // validated above + Mode: mode, + TargetProtocol: exposeTargetProtocol(exposeReq.Protocol), + Domain: exposeReq.Domain, + Pin: exposeReq.Pin, + Password: exposeReq.Password, + UserGroups: exposeReq.UserGroups, + ListenPort: uint16(exposeReq.ListenPort), //nolint:gosec // validated above + }) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{ + ServiceName: created.ServiceName, + ServiceUrl: created.ServiceURL, + Domain: created.Domain, + PortAutoAssigned: created.PortAutoAssigned, + }) +} + +// RenewExpose extends the TTL of an active expose session. +func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + renewReq := &proto.RenewExposeRequest{} + peerKey, err := s.parseRequest(ctx, req, renewReq) + if err != nil { + return nil, err + } + + accountID, peer, err := s.authenticateExposePeer(ctx, peerKey) + if err != nil { + return nil, err + } + + reverseProxyMgr := s.getReverseProxyManager() + if reverseProxyMgr == nil { + return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") + } + + serviceID, err := s.resolveServiceID(ctx, renewReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { + return nil, mapExposeError(ctx, err) + } + + return s.encryptResponse(peerKey, &proto.RenewExposeResponse{}) +} + +// StopExpose terminates an active expose session. +func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + stopReq := &proto.StopExposeRequest{} + peerKey, err := s.parseRequest(ctx, req, stopReq) + if err != nil { + return nil, err + } + + accountID, peer, err := s.authenticateExposePeer(ctx, peerKey) + if err != nil { + return nil, err + } + + reverseProxyMgr := s.getReverseProxyManager() + if reverseProxyMgr == nil { + return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") + } + + serviceID, err := s.resolveServiceID(ctx, stopReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { + return nil, mapExposeError(ctx, err) + } + + return s.encryptResponse(peerKey, &proto.StopExposeResponse{}) +} + +func mapExposeError(ctx context.Context, err error) error { + s, ok := internalStatus.FromError(err) + if !ok { + log.WithContext(ctx).Errorf("expose service error: %v", err) + return status.Errorf(codes.Internal, "internal error") + } + + switch s.Type() { + case internalStatus.InvalidArgument: + return status.Errorf(codes.InvalidArgument, "%s", s.Message) + case internalStatus.PermissionDenied: + return status.Errorf(codes.PermissionDenied, "%s", s.Message) + case internalStatus.NotFound: + return status.Errorf(codes.NotFound, "%s", s.Message) + case internalStatus.AlreadyExists: + return status.Errorf(codes.AlreadyExists, "%s", s.Message) + case internalStatus.PreconditionFailed: + return status.Errorf(codes.ResourceExhausted, "%s", s.Message) + default: + log.WithContext(ctx).Errorf("expose service error: %v", err) + return status.Errorf(codes.Internal, "internal error") + } +} + +func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) { + wgKey, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "internal error") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, msg) + if err != nil { + return nil, status.Errorf(codes.Internal, "encrypt response") + } + + return &proto.EncryptedMessage{ + WgPubKey: wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} + +func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key) (string, *nbpeer.Peer, error) { + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) + if err != nil { + if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered") + } + return "", nil, status.Errorf(codes.Internal, "lookup account for peer") + } + + peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String()) + if err != nil { + return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered") + } + + return accountID, peer, nil +} + +func (s *Server) getReverseProxyManager() rpservice.Manager { + s.reverseProxyMu.RLock() + defer s.reverseProxyMu.RUnlock() + return s.reverseProxyManager +} + +// SetReverseProxyManager sets the reverse proxy manager on the server. +func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) { + s.reverseProxyMu.Lock() + defer s.reverseProxyMu.Unlock() + s.reverseProxyManager = mgr +} + +// resolveServiceID looks up the service by its globally unique domain. +func (s *Server) resolveServiceID(ctx context.Context, domain string) (string, error) { + if domain == "" { + return "", status.Errorf(codes.InvalidArgument, "domain is required") + } + + svc, err := s.accountManager.GetStore().GetServiceByDomain(ctx, domain) + if err != nil { + return "", err + } + return svc.ID, nil +} + +func exposeProtocolToString(p proto.ExposeProtocol) (string, error) { + switch p { + case proto.ExposeProtocol_EXPOSE_HTTP, proto.ExposeProtocol_EXPOSE_HTTPS: + return "http", nil + case proto.ExposeProtocol_EXPOSE_TCP: + return "tcp", nil + case proto.ExposeProtocol_EXPOSE_UDP: + return "udp", nil + case proto.ExposeProtocol_EXPOSE_TLS: + return "tls", nil + default: + return "", fmt.Errorf("unsupported expose protocol: %v", p) + } +} + +// exposeTargetProtocol returns the target protocol for the given expose protocol. +// For HTTP mode, this is http or https (the scheme used to connect to the backend). +// For L4 modes, this is tcp or udp (the transport used to connect to the backend). +func exposeTargetProtocol(p proto.ExposeProtocol) string { + switch p { + case proto.ExposeProtocol_EXPOSE_HTTPS: + return rpservice.TargetProtoHTTPS + case proto.ExposeProtocol_EXPOSE_TCP, proto.ExposeProtocol_EXPOSE_TLS: + return rpservice.TargetProtoTCP + case proto.ExposeProtocol_EXPOSE_UDP: + return rpservice.TargetProtoUDP + default: + return rpservice.TargetProtoHTTP + } +} diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index dcc37c639..acfd6eafb 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -1,28 +1,21 @@ package grpc import ( + "context" "crypto/rand" + "crypto/sha256" "crypto/subtle" "encoding/base64" + "encoding/hex" + "encoding/json" "fmt" - "sync" "time" + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" ) -// OneTimeTokenStore manages short-lived, single-use authentication tokens -// for proxy-to-management RPC authentication. Tokens are generated when -// a service is created and must be used exactly once by the proxy -// to authenticate a subsequent RPC call. -type OneTimeTokenStore struct { - tokens map[string]*tokenMetadata - mu sync.RWMutex - cleanup *time.Ticker - cleanupDone chan struct{} -} - -// tokenMetadata stores information about a one-time token type tokenMetadata struct { ServiceID string AccountID string @@ -30,20 +23,19 @@ type tokenMetadata struct { CreatedAt time.Time } -// NewOneTimeTokenStore creates a new token store with automatic cleanup -// of expired tokens. The cleanupInterval determines how often expired -// tokens are removed from memory. -func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { - store := &OneTimeTokenStore{ - tokens: make(map[string]*tokenMetadata), - cleanup: time.NewTicker(cleanupInterval), - cleanupDone: make(chan struct{}), +// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC. +// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. +type OneTimeTokenStore struct { + cache *cache.Cache[string] + ctx context.Context +} + +// NewOneTimeTokenStore creates a token store using the provided shared cache store. +func NewOneTimeTokenStore(ctx context.Context, cacheStore store.StoreInterface) *OneTimeTokenStore { + return &OneTimeTokenStore{ + cache: cache.New[string](cacheStore), + ctx: ctx, } - - // Start background cleanup goroutine - go store.cleanupExpired() - - return store } // GenerateToken creates a new cryptographically secure one-time token @@ -52,25 +44,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { // // Returns the generated token string or an error if random generation fails. func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) { - // Generate 32 bytes (256 bits) of cryptographically secure random data randomBytes := make([]byte, 32) if _, err := rand.Read(randomBytes); err != nil { return "", fmt.Errorf("failed to generate random token: %w", err) } - // Encode as URL-safe base64 for easy transmission in gRPC token := base64.URLEncoding.EncodeToString(randomBytes) + hashedToken := hashToken(token) - s.mu.Lock() - defer s.mu.Unlock() - - s.tokens[token] = &tokenMetadata{ + metadata := &tokenMetadata{ ServiceID: serviceID, AccountID: accountID, ExpiresAt: time.Now().Add(ttl), CreatedAt: time.Now(), } + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return "", fmt.Errorf("failed to serialize token metadata: %w", err) + } + + if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil { + return "", fmt.Errorf("failed to store token: %w", err) + } + log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)", serviceID, accountID, ttl) @@ -88,80 +85,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time. // - Account ID doesn't match // - Reverse proxy ID doesn't match func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { - s.mu.Lock() - defer s.mu.Unlock() + hashedToken := hashToken(token) - metadata, exists := s.tokens[token] - if !exists { - log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", - serviceID, accountID) + metadataJSON, err := s.cache.Get(s.ctx, hashedToken) + if err != nil { + log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("invalid token") } - // Check expiration + metadata := &tokenMetadata{} + if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil { + log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err) + return fmt.Errorf("invalid token metadata") + } + if time.Now().After(metadata.ExpiresAt) { - delete(s.tokens, token) - log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", - serviceID, accountID) + log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("token expired") } - // Validate account ID using constant-time comparison (prevents timing attacks) if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 { - log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", - metadata.AccountID, accountID) + log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID) return fmt.Errorf("account ID mismatch") } - // Validate service ID using constant-time comparison if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 { - log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", - metadata.ServiceID, serviceID) + log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID) return fmt.Errorf("service ID mismatch") } - // Delete token immediately to enforce single-use - delete(s.tokens, token) + if err := s.cache.Delete(s.ctx, hashedToken); err != nil { + log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err) + } - log.Infof("Token validated and consumed for proxy %s in account %s", - serviceID, accountID) + log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID) return nil } -// cleanupExpired removes expired tokens in the background to prevent memory leaks -func (s *OneTimeTokenStore) cleanupExpired() { - for { - select { - case <-s.cleanup.C: - s.mu.Lock() - now := time.Now() - removed := 0 - for token, metadata := range s.tokens { - if now.After(metadata.ExpiresAt) { - delete(s.tokens, token) - removed++ - } - } - if removed > 0 { - log.Debugf("Cleaned up %d expired one-time tokens", removed) - } - s.mu.Unlock() - case <-s.cleanupDone: - return - } - } -} - -// Close stops the cleanup goroutine and releases resources -func (s *OneTimeTokenStore) Close() { - s.cleanup.Stop() - close(s.cleanupDone) -} - -// GetTokenCount returns the current number of tokens in the store (for debugging/metrics) -func (s *OneTimeTokenStore) GetTokenCount() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.tokens) +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) } diff --git a/management/internals/shared/grpc/pkce_verifier.go b/management/internals/shared/grpc/pkce_verifier.go new file mode 100644 index 000000000..a1325256c --- /dev/null +++ b/management/internals/shared/grpc/pkce_verifier.go @@ -0,0 +1,54 @@ +package grpc + +import ( + "context" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" + log "github.com/sirupsen/logrus" +) + +// PKCEVerifierStore manages PKCE verifiers for OAuth flows. +// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. +type PKCEVerifierStore struct { + cache *cache.Cache[string] + ctx context.Context +} + +// NewPKCEVerifierStore creates a PKCE verifier store using the provided shared cache store. +func NewPKCEVerifierStore(ctx context.Context, cacheStore store.StoreInterface) *PKCEVerifierStore { + return &PKCEVerifierStore{ + cache: cache.New[string](cacheStore), + ctx: ctx, + } +} + +// Store saves a PKCE verifier associated with an OAuth state parameter. +// The verifier is stored with the specified TTL and will be automatically deleted after expiration. +func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error { + if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil { + return fmt.Errorf("failed to store PKCE verifier: %w", err) + } + + log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl) + return nil +} + +// LoadAndDelete retrieves and removes a PKCE verifier for the given state. +// Returns the verifier and true if found, or empty string and false if not found. +// This enforces single-use semantics for PKCE verifiers. +func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) { + verifier, err := s.cache.Get(s.ctx, state) + if err != nil { + log.Debugf("PKCE verifier not found for state") + return "", false + } + + if err := s.cache.Delete(s.ctx, state); err != nil { + log.Warnf("Failed to delete PKCE verifier for state: %v", err) + } + + return verifier, true +} diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 4771d35af..a5e352e75 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net/http" "net/url" "strings" "sync" @@ -18,20 +19,21 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/oauth2" "google.golang.org/grpc/codes" - "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/internals/modules/peers" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" proxyauth "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/management/proto" + nbstatus "github.com/netbirdio/netbird/shared/management/status" ) type ProxyOIDCConfig struct { @@ -45,12 +47,6 @@ type ProxyOIDCConfig struct { KeysLocation string } -// ClusterInfo contains information about a proxy cluster. -type ClusterInfo struct { - Address string - ConnectedProxies int -} - // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -58,17 +54,17 @@ type ProxyServiceServer struct { // Map of connected proxies: proxy_id -> proxy connection connectedProxies sync.Map - // Map of cluster address -> set of proxy IDs - clusterProxies sync.Map - - // Channel for broadcasting reverse proxy updates to all proxies - updatesChan chan *proto.ProxyMapping - // Manager for access logs accessLogManager accesslogs.Manager + mu sync.RWMutex // Manager for reverse proxy operations - reverseProxyManager reverseproxy.Manager + serviceManager rpservice.Manager + // ProxyController for service updates and cluster management + proxyController proxy.Controller + + // Manager for proxy connections + proxyManager proxy.Manager // Manager for peers peersManager peers.Manager @@ -82,84 +78,82 @@ type ProxyServiceServer struct { // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig - // TODO: use database to store these instead? - // pkceVerifiers stores PKCE code verifiers keyed by OAuth state. - // Entries expire after pkceVerifierTTL to prevent unbounded growth. - pkceVerifiers sync.Map - pkceCleanupCancel context.CancelFunc + // Store for PKCE verifiers + pkceVerifierStore *PKCEVerifierStore + + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute -type pkceEntry struct { - verifier string - createdAt time.Time -} - // proxyConnection represents a connected proxy type proxyConnection struct { - proxyID string - address string - stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.ProxyMapping - ctx context.Context - cancel context.CancelFunc + proxyID string + address string + capabilities *proto.ProxyCapabilities + stream proto.ProxyService_GetMappingUpdateServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc } // NewProxyServiceServer creates a new proxy service server. -func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ - updatesChan: make(chan *proto.ProxyMapping, 100), accessLogManager: accessLogMgr, oidcConfig: oidcConfig, tokenStore: tokenStore, + pkceVerifierStore: pkceStore, peersManager: peersManager, usersManager: usersManager, - pkceCleanupCancel: cancel, + proxyManager: proxyMgr, + cancel: cancel, } - go s.cleanupPKCEVerifiers(ctx) + go s.cleanupStaleProxies(ctx) return s } -// cleanupPKCEVerifiers periodically removes expired PKCE verifiers. -func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) { - ticker := time.NewTicker(pkceVerifierTTL) +// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes +func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { + ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: - now := time.Now() - s.pkceVerifiers.Range(func(key, value any) bool { - if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL { - s.pkceVerifiers.Delete(key) - } - return true - }) + if err := s.proxyManager.CleanupStale(ctx, 1*time.Hour); err != nil { + log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err) + } } } } // Close stops background goroutines. func (s *ProxyServiceServer) Close() { - s.pkceCleanupCancel() + s.cancel() } -func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) { - s.reverseProxyManager = manager +// SetServiceManager sets the service manager. Must be called before serving. +func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { + s.mu.Lock() + defer s.mu.Unlock() + s.serviceManager = manager +} + +// SetProxyController sets the proxy controller. Must be called before serving. +func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) { + s.mu.Lock() + defer s.mu.Unlock() + s.proxyController = proxyController } // GetMappingUpdate handles the control stream with proxy clients func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { ctx := stream.Context() - peerInfo := "" - if p, ok := peer.FromContext(ctx); ok { - peerInfo = p.Addr.String() - } - + peerInfo := PeerIPFromContext(ctx) log.Infof("New proxy connection from %s", peerInfo) proxyID := req.GetProxyId() @@ -174,16 +168,38 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ - proxyID: proxyID, - address: proxyAddress, - stream: stream, - sendChan: make(chan *proto.ProxyMapping, 100), - ctx: connCtx, - cancel: cancel, + proxyID: proxyID, + address: proxyAddress, + capabilities: req.GetCapabilities(), + stream: stream, + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), + ctx: connCtx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) - s.addToCluster(conn.address, proxyID) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } + + // Register proxy in database with capabilities + var caps *proxy.Capabilities + if c := req.GetCapabilities(); c != nil { + caps = &proxy.Capabilities{ + SupportsCustomPorts: c.SupportsCustomPorts, + RequireSubdomain: c.RequireSubdomain, + SupportsCrowdsec: c.SupportsCrowdsec, + } + } + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) + s.connectedProxies.Delete(proxyID) + if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) + } + return status.Errorf(codes.Internal, "register proxy in database: %v", err) + } + log.WithFields(log.Fields{ "proxy_id": proxyID, "address": proxyAddress, @@ -191,8 +207,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { + if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + } + s.connectedProxies.Delete(proxyID) - s.removeFromCluster(conn.address, proxyID) + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { + log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) + } + cancel() log.Infof("Proxy %s disconnected", proxyID) }() @@ -204,6 +227,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest errChan := make(chan error, 2) go s.sender(conn, errChan) + // Start heartbeat goroutine + go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) + select { case err := <-errChan: return fmt.Errorf("send update to proxy %s: %w", proxyID, err) @@ -212,30 +238,36 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only services matching the proxy's cluster address are sent. -func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - services, err := s.reverseProxyManager.GetGlobalServices(ctx) - if err != nil { - return fmt.Errorf("get services from store: %w", err) - } +// heartbeat updates the proxy's last_seen timestamp every minute +func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) + } + case <-ctx.Done(): + return + } + } +} + +// sendSnapshot sends the initial snapshot of services to the connecting proxy. +// Only entries matching the proxy's cluster address are sent. +func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") } - var filtered []*reverseproxy.Service - for _, service := range services { - if !service.Enabled { - continue - } - if service.ProxyCluster == "" || service.ProxyCluster != conn.address { - continue - } - filtered = append(filtered, service) + mappings, err := s.snapshotServiceMappings(ctx, conn) + if err != nil { + return err } - if len(filtered) == 0 { + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { @@ -244,9 +276,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } - for i, service := range filtered { - // Generate one-time authentication token for each service in the snapshot - // Tokens are not persistent on the proxy, so we need to generate new ones on reconnection + for i, m := range mappings { + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{m}, + InitialSyncComplete: i == len(mappings)-1, + }); err != nil { + return fmt.Errorf("send proxy mapping: %w", err) + } + } + + return nil +} + +func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { + services, err := s.serviceManager.GetGlobalServices(ctx) + if err != nil { + return nil, fmt.Errorf("get services from store: %w", err) + } + + var mappings []*proto.ProxyMapping + for _, service := range services { + if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { + continue + } + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) if err != nil { log.WithFields(log.Fields{ @@ -256,25 +309,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec continue } - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{ - service.ToProtoMapping( - reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy. - token, - s.GetOIDCValidationConfig(), - ), - }, - InitialSyncComplete: i == len(filtered)-1, - }); err != nil { - log.WithFields(log.Fields{ - "domain": service.Domain, - "account": service.AccountID, - }).WithError(err).Error("failed to send proxy mapping") - return fmt.Errorf("send proxy mapping: %w", err) + m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + if !proxyAcceptsMapping(conn, m) { + continue } + mappings = append(mappings, m) } - - return nil + return mappings, nil } // isProxyAddressValid validates a proxy address @@ -287,8 +328,8 @@ func isProxyAddressValid(addr string) bool { func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { select { - case msg := <-conn.sendChan: - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil { + case resp := <-conn.sendChan: + if err := conn.stream.Send(resp); err != nil { errChan <- err return } @@ -339,17 +380,17 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. -func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { +func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - msg := s.perProxyMessage(update, conn.proxyID) - if msg == nil { + resp := s.perProxyMessage(update, conn.proxyID) + if resp == nil { return true } select { - case conn.sendChan <- msg: - log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID) + case conn.sendChan <- resp: + log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) } @@ -393,81 +434,99 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string { return urls } -// addToCluster registers a proxy in a cluster. -func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) { - if clusterAddr == "" { - return - } - proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) - proxySet.(*sync.Map).Store(proxyID, struct{}{}) - log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr) -} - -// removeFromCluster removes a proxy from a cluster. -func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { - if clusterAddr == "" { - return - } - if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok { - proxySet.(*sync.Map).Delete(proxyID) - log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr) - } -} - // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. -func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { +func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) { + updateResponse := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{update}, + } + if clusterAddr == "" { - s.SendServiceUpdate(update) + s.SendServiceUpdate(updateResponse) return } - proxySet, ok := s.clusterProxies.Load(clusterAddr) - if !ok { - log.Debugf("No proxies connected for cluster %s", clusterAddr) + if s.proxyController == nil { + log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr) + return + } + + proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr) + if len(proxyIDs) == 0 { + log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr) return } log.Debugf("Sending service update to cluster %s", clusterAddr) - proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { - proxyID := key.(string) - if connVal, ok := s.connectedProxies.Load(proxyID); ok { - conn := connVal.(*proxyConnection) - msg := s.perProxyMessage(update, proxyID) - if msg == nil { - return true - } - select { - case conn.sendChan <- msg: - log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) - default: - log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) - } + for _, proxyID := range proxyIDs { + connVal, ok := s.connectedProxies.Load(proxyID) + if !ok { + continue } + conn := connVal.(*proxyConnection) + if !proxyAcceptsMapping(conn, update) { + log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) + continue + } + msg := s.perProxyMessage(updateResponse, proxyID) + if msg == nil { + continue + } + select { + case conn.sendChan <- msg: + log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + default: + log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + } + } +} + +// proxyAcceptsMapping returns whether the proxy should receive this mapping. +// Old proxies that never reported capabilities are skipped for non-TLS L4 +// mappings with a custom listen port, since they don't understand the +// protocol. Proxies that report capabilities (even SupportsCustomPorts=false) +// are new enough to handle the mapping. TLS uses SNI routing and works on +// any proxy. Delete operations are always sent so proxies can clean up. +func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { return true - }) + } + if mapping.ListenPort == 0 || mapping.Mode == "tls" { + return true + } + // Old proxies that never reported capabilities don't understand + // custom port mappings. + return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil } // perProxyMessage returns a copy of update with a fresh one-time token for -// create/update operations. For delete operations the original message is -// returned unchanged because proxies do not need to authenticate for removal. +// create/update operations. For delete operations the original mapping is +// used unchanged because proxies do not need to authenticate for removal. // Returns nil if token generation fails (the proxy should be skipped). -func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping { - if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" { - return update +func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { + resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) + for _, mapping := range update.Mapping { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { + resp = append(resp, mapping) + continue + } + + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + if err != nil { + log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) + return nil + } + + msg := shallowCloneMapping(mapping) + msg.AuthToken = token + resp = append(resp, msg) } - token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute) - if err != nil { - log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) - return nil + return &proto.GetMappingUpdateResponse{ + Mapping: resp, } - - msg := shallowCloneMapping(update) - msg.AuthToken = token - return msg } // shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the @@ -475,46 +534,22 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID // should be set on the copy. func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { return &proto.ProxyMapping{ - Type: m.Type, - Id: m.Id, - AccountId: m.AccountId, - Domain: m.Domain, - Path: m.Path, - Auth: m.Auth, - PassHostHeader: m.PassHostHeader, - RewriteRedirects: m.RewriteRedirects, + Type: m.Type, + Id: m.Id, + AccountId: m.AccountId, + Domain: m.Domain, + Path: m.Path, + Auth: m.Auth, + PassHostHeader: m.PassHostHeader, + RewriteRedirects: m.RewriteRedirects, + Mode: m.Mode, + ListenPort: m.ListenPort, + AccessRestrictions: m.AccessRestrictions, } } -// GetAvailableClusters returns information about all connected proxy clusters. -func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo { - clusterCounts := make(map[string]int) - s.clusterProxies.Range(func(key, value interface{}) bool { - clusterAddr := key.(string) - proxySet := value.(*sync.Map) - count := 0 - proxySet.Range(func(_, _ interface{}) bool { - count++ - return true - }) - if count > 0 { - clusterCounts[clusterAddr] = count - } - return true - }) - - clusters := make([]ClusterInfo, 0, len(clusterCounts)) - for addr, count := range clusterCounts { - clusters = append(clusters, ClusterInfo{ - Address: addr, - ConnectedProxies: count, - }) - } - return clusters -} - func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { - service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err) @@ -533,18 +568,20 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen }, nil } -func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) { +func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) { switch v := req.GetRequest().(type) { case *proto.AuthenticateRequest_Pin: return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) case *proto.AuthenticateRequest_Password: return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth) + case *proto.AuthenticateRequest_HeaderAuth: + return s.authenticateHeader(ctx, req.GetId(), v, service.Auth.HeaderAuths) default: return false, "", "" } } -func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) { +func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) { if auth == nil || !auth.Enabled { log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID) return false, "", "" @@ -558,7 +595,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri return true, "pin-user", proxyauth.MethodPIN } -func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) { +func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) { if auth == nil || !auth.Enabled { log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID) return false, "", "" @@ -572,6 +609,35 @@ func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID return true, "password-user", proxyauth.MethodPassword } +func (s *ProxyServiceServer) authenticateHeader(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_HeaderAuth, auths []*rpservice.HeaderAuthConfig) (bool, string, proxyauth.Method) { + if len(auths) == 0 { + log.WithContext(ctx).Debugf("header authentication attempted but no header auths configured for service %s", serviceID) + return false, "", "" + } + + headerName := http.CanonicalHeaderKey(req.HeaderAuth.GetHeaderName()) + + var lastErr error + for _, auth := range auths { + if auth == nil || !auth.Enabled { + continue + } + if headerName != "" && http.CanonicalHeaderKey(auth.Header) != headerName { + continue + } + if err := argon2id.Verify(req.HeaderAuth.GetHeaderValue(), auth.Value); err != nil { + lastErr = err + continue + } + return true, "header-user", proxyauth.MethodHeader + } + + if lastErr != nil { + s.logAuthenticationError(ctx, lastErr, "Header") + } + return false, "", "" +} + func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) { if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) { log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType) @@ -580,7 +646,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err } } -func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) { +func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) { if !authenticated || service.SessionPrivateKey == "" { return "", nil } @@ -600,7 +666,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic return token, nil } -// SendStatusUpdate handles status updates from proxy clients +// SendStatusUpdate handles status updates from proxy clients. func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { accountID := req.GetAccountId() serviceID := req.GetServiceId() @@ -619,8 +685,19 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required") } + internalStatus := protoStatusToInternal(protoStatus) + + if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { + sErr, isNbErr := nbstatus.FromError(err) + if isNbErr && sErr.Type() == nbstatus.NotFound { + return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID) + } + log.WithContext(ctx).WithError(err).Error("failed to update service status") + return nil, status.Errorf(codes.Internal, "update service status: %v", err) + } + if certificateIssued { - if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { + if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err) } @@ -630,13 +707,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se }).Info("Certificate issued timestamp updated") } - internalStatus := protoStatusToInternal(protoStatus) - - if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { - log.WithContext(ctx).WithError(err).Error("failed to update service status") - return nil, status.Errorf(codes.Internal, "update service status: %v", err) - } - log.WithFields(log.Fields{ "service_id": serviceID, "account_id": accountID, @@ -646,23 +716,23 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return &proto.SendStatusUpdateResponse{}, nil } -// protoStatusToInternal maps proto status to internal status -func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus { +// protoStatusToInternal maps proto status to internal service status. +func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { switch protoStatus { case proto.ProxyStatus_PROXY_STATUS_PENDING: - return reverseproxy.StatusPending + return rpservice.StatusPending case proto.ProxyStatus_PROXY_STATUS_ACTIVE: - return reverseproxy.StatusActive + return rpservice.StatusActive case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED: - return reverseproxy.StatusTunnelNotCreated + return rpservice.StatusTunnelNotCreated case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING: - return reverseproxy.StatusCertificatePending + return rpservice.StatusCertificatePending case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED: - return reverseproxy.StatusCertificateFailed + return rpservice.StatusCertificateFailed case proto.ProxyStatus_PROXY_STATUS_ERROR: - return reverseproxy.StatusError + return rpservice.StatusError default: - return reverseproxy.StatusError + return rpservice.StatusError } } @@ -726,8 +796,11 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU if err != nil { return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) } + if redirectURL.Scheme != "https" && redirectURL.Scheme != "http" { + return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme") + } // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. - services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId()) + services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId()) if err != nil { log.WithContext(ctx).Errorf("failed to get account services: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err) @@ -771,7 +844,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum) codeVerifier := oauth2.GenerateVerifier() - s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()}) + if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil { + log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err) + return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err) + } return &proto.GetOIDCURLResponse{ Url: (&oauth2.Config{ @@ -790,8 +866,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig { // GetOIDCValidationConfig returns the OIDC configuration for token validation // in the format needed by ToProtoMapping. -func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig { - return reverseproxy.OIDCValidationConfig{ +func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig { + return proxy.OIDCValidationConfig{ Issuer: s.oidcConfig.Issuer, Audiences: []string{s.oidcConfig.Audience}, KeysLocation: s.oidcConfig.KeysLocation, @@ -807,20 +883,9 @@ func (s *ProxyServiceServer) generateHMAC(input string) string { // ValidateState validates the state parameter from an OAuth callback. // Returns the original redirect URL if valid, or an error if invalid. +// The HMAC is verified before consuming the PKCE verifier to prevent +// an attacker from invalidating a legitimate user's auth flow. func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { - v, ok := s.pkceVerifiers.LoadAndDelete(state) - if !ok { - return "", "", errors.New("no verifier for state") - } - entry, ok := v.(pkceEntry) - if !ok { - return "", "", errors.New("invalid verifier for state") - } - if time.Since(entry.createdAt) > pkceVerifierTTL { - return "", "", errors.New("PKCE verifier expired") - } - verifier = entry.verifier - // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) parts := strings.Split(state, "|") if len(parts) != 3 { @@ -844,18 +909,24 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL return "", "", errors.New("invalid state signature") } + // Consume the PKCE verifier only after HMAC validation passes. + verifier, ok := s.pkceVerifierStore.LoadAndDelete(state) + if !ok { + return "", "", errors.New("no verifier for state") + } + return verifier, redirectURL, nil } // GenerateSessionToken creates a signed session JWT for the given domain and user. func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { // Find the service by domain to get its signing key - services, err := s.reverseProxyManager.GetGlobalServices(ctx) + services, err := s.serviceManager.GetGlobalServices(ctx) if err != nil { return "", fmt.Errorf("get services: %w", err) } - var service *reverseproxy.Service + var service *rpservice.Service for _, svc := range services { if svc.Domain == domain { service = svc @@ -921,8 +992,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain) } -func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { - services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID) +func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) { + services, err := s.serviceManager.GetAccountServices(ctx, accountID) if err != nil { return nil, fmt.Errorf("get account services: %w", err) } @@ -1043,8 +1114,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } -func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) { - services, err := s.reverseProxyManager.GetGlobalServices(ctx) +func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { + services, err := s.serviceManager.GetGlobalServices(ctx) if err != nil { return nil, fmt.Errorf("get services: %w", err) } @@ -1058,7 +1129,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri return nil, fmt.Errorf("service not found for domain: %s", domain) } -func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error { +func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled { return nil } @@ -1081,3 +1152,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, use return fmt.Errorf("user not in allowed groups") } + +func ptr[T any](v T) *T { return &v } diff --git a/management/internals/shared/grpc/proxy_auth.go b/management/internals/shared/grpc/proxy_auth.go index 6daeab5f2..dd593dfa0 100644 --- a/management/internals/shared/grpc/proxy_auth.go +++ b/management/internals/shared/grpc/proxy_auth.go @@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter } func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) { - clientIP := peerIPFromContext(ctx) + clientIP := PeerIPFromContext(ctx) if clientIP != "" && i.failureLimiter.isLimited(clientIP) { return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts") diff --git a/management/internals/shared/grpc/proxy_auth_ratelimit.go b/management/internals/shared/grpc/proxy_auth_ratelimit.go index 447e531b0..78ab1bd20 100644 --- a/management/internals/shared/grpc/proxy_auth_ratelimit.go +++ b/management/internals/shared/grpc/proxy_auth_ratelimit.go @@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() { l.cancel() } -// peerIPFromContext extracts the client IP from the gRPC context. +// PeerIPFromContext extracts the client IP from the gRPC context. // Uses realip (from trusted proxy headers) first, falls back to the transport peer address. -func peerIPFromContext(ctx context.Context) clientIP { +func PeerIPFromContext(ctx context.Context) string { if addr, ok := realip.FromContext(ctx); ok { return addr.String() } diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 84fb54923..0fa9a0dc1 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -8,40 +8,45 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/types" ) type mockReverseProxyManager struct { - proxiesByAccount map[string][]*reverseproxy.Service + proxiesByAccount map[string][]*service.Service err error } -func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + +func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { if m.err != nil { return nil, m.err } return m.proxiesByAccount[accountID], nil } -func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return nil, nil } -func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { - return []*reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { + return []*service.Service{}, nil } -func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) { + return &service.Service{}, nil } -func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) { + return &service.Service{}, nil } -func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) { + return &service.Service{}, nil } func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error { @@ -52,7 +57,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac return nil } -func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error { +func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error { return nil } @@ -64,14 +69,32 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, return nil } -func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) { + return &service.Service{}, nil } func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { return "", nil } +func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + return &service.ExposeServiceResponse{}, nil +} + +func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} + +func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { + return nil, nil +} + type mockUsersManager struct { users map[string]*types.User err error @@ -93,7 +116,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name string domain string userID string - proxiesByAccount map[string][]*reverseproxy.Service + proxiesByAccount map[string][]*service.Service users map[string]*types.User proxyErr error userErr error @@ -104,7 +127,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user not found", domain: "app.example.com", userID: "unknown-user", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{Domain: "app.example.com", AccountID: "account1"}}, }, users: map[string]*types.User{}, @@ -115,7 +138,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "proxy not found in user's account", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{}, + proxiesByAccount: map[string][]*service.Service{}, users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "account1"}, }, @@ -126,7 +149,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "proxy exists in different account - not accessible", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account2": {{Domain: "app.example.com", AccountID: "account2"}}, }, users: map[string]*types.User{ @@ -139,8 +162,8 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "no bearer auth configured - same account allows access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ - "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}}, + proxiesByAccount: map[string][]*service.Service{ + "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}}, }, users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "account1"}, @@ -151,12 +174,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "bearer auth disabled - same account allows access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false}, + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{Enabled: false}, }, }}, }, @@ -169,12 +192,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "bearer auth enabled but no groups configured - same account allows access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{}, }, @@ -190,12 +213,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user not in allowed groups", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"group1", "group2"}, }, @@ -212,12 +235,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user in one of the allowed groups - allow access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"group1", "group2"}, }, @@ -233,12 +256,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user in all allowed groups - allow access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"group1", "group2"}, }, @@ -266,10 +289,10 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "multiple proxies in account - finds correct one", domain: "app2.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": { {Domain: "app1.example.com", AccountID: "account1"}, - {Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}, + {Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}}, {Domain: "app3.example.com", AccountID: "account1"}, }, }, @@ -283,7 +306,7 @@ func TestValidateUserGroupAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := &ProxyServiceServer{ - reverseProxyManager: &mockReverseProxyManager{ + serviceManager: &mockReverseProxyManager{ proxiesByAccount: tt.proxiesByAccount, err: tt.proxyErr, }, @@ -310,7 +333,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name string accountID string domain string - proxiesByAccount map[string][]*reverseproxy.Service + proxiesByAccount map[string][]*service.Service err error expectProxy bool expectErr bool @@ -319,7 +342,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name: "proxy found", accountID: "account1", domain: "app.example.com", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": { {Domain: "other.example.com", AccountID: "account1"}, {Domain: "app.example.com", AccountID: "account1"}, @@ -332,7 +355,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name: "proxy not found in account", accountID: "account1", domain: "unknown.example.com", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{Domain: "app.example.com", AccountID: "account1"}}, }, expectProxy: false, @@ -342,7 +365,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name: "empty proxy list for account", accountID: "account1", domain: "app.example.com", - proxiesByAccount: map[string][]*reverseproxy.Service{}, + proxiesByAccount: map[string][]*service.Service{}, expectProxy: false, expectErr: true, }, @@ -360,7 +383,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := &ProxyServiceServer{ - reverseProxyManager: &mockReverseProxyManager{ + serviceManager: &mockReverseProxyManager{ proxiesByAccount: tt.proxiesByAccount, err: tt.err, }, diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 4c84e6010..de4e96d93 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -1,6 +1,7 @@ package grpc import ( + "context" "crypto/rand" "encoding/base64" "strings" @@ -8,57 +9,139 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/shared/management/proto" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + +type testProxyController struct { + mu sync.Mutex + clusterProxies map[string]map[string]struct{} +} + +func newTestProxyController() *testProxyController { + return &testProxyController{ + clusterProxies: make(map[string]map[string]struct{}), + } +} + +func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) { +} + +func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig { + return proxy.OIDCValidationConfig{} +} + +func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error { + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.clusterProxies[clusterAddr]; !ok { + c.clusterProxies[clusterAddr] = make(map[string]struct{}) + } + c.clusterProxies[clusterAddr][proxyID] = struct{}{} + return nil +} + +func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error { + c.mu.Lock() + defer c.mu.Unlock() + if proxies, ok := c.clusterProxies[clusterAddr]; ok { + delete(proxies, proxyID) + } + return nil +} + +func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { + c.mu.Lock() + defer c.mu.Unlock() + proxies, ok := c.clusterProxies[clusterAddr] + if !ok { + return nil + } + result := make([]string, 0, len(proxies)) + for id := range proxies { + result = append(result, id) + } + return result +} + // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. -func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping { - ch := make(chan *proto.ProxyMapping, 10) +func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { + return registerFakeProxyWithCaps(s, proxyID, clusterAddr, nil) +} + +// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. +func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { + ch := make(chan *proto.GetMappingUpdateResponse, 10) conn := &proxyConnection{ - proxyID: proxyID, - address: clusterAddr, - sendChan: ch, + proxyID: proxyID, + address: clusterAddr, + capabilities: caps, + sendChan: ch, } s.connectedProxies.Store(proxyID, conn) - proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) - proxySet.(*sync.Map).Store(proxyID, struct{}{}) + _ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID) return ch } -func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping { +// drainMapping drains a single ProxyMapping from the channel. +func drainMapping(ch chan *proto.GetMappingUpdateResponse) *proto.ProxyMapping { select { - case msg := <-ch: - return msg + case resp := <-ch: + if len(resp.Mapping) > 0 { + return resp.Mapping[0] + } + return nil case <-time.After(time.Second): return nil } } +// drainEmpty checks if a channel has no message within timeout. +func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { + select { + case <-ch: + return false + case <-time.After(100 * time.Millisecond): + return true + } +} + func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + ctx := context.Background() + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } + s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" const numProxies = 3 - channels := make([]chan *proto.ProxyMapping, numProxies) + channels := make([]chan *proto.GetMappingUpdateResponse, numProxies) for i := range numProxies { id := "proxy-" + string(rune('a'+i)) channels[i] = registerFakeProxy(s, id, cluster) } - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-1", AccountId: "account-1", @@ -68,14 +151,14 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { }, } - s.SendServiceUpdateToCluster(update, cluster) + s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) tokens := make([]string, numProxies) for i, ch := range channels { - msg := drainChannel(ch) + msg := drainMapping(ch) require.NotNil(t, msg, "proxy %d should receive a message", i) - assert.Equal(t, update.Domain, msg.Domain) - assert.Equal(t, update.Id, msg.Id) + assert.Equal(t, mapping.Domain, msg.Domain) + assert.Equal(t, mapping.Id, msg.Id) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) tokens[i] = msg.AuthToken } @@ -96,64 +179,69 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { } func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + ctx := context.Background() + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } + s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" ch1 := registerFakeProxy(s, "proxy-a", cluster) ch2 := registerFakeProxy(s, "proxy-b", cluster) - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "service-1", AccountId: "account-1", Domain: "test.example.com", } - s.SendServiceUpdateToCluster(update, cluster) + s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) - msg1 := drainChannel(ch1) - msg2 := drainChannel(ch2) + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) require.NotNil(t, msg1) require.NotNil(t, msg2) // Delete operations should not generate tokens assert.Empty(t, msg1.AuthToken) assert.Empty(t, msg2.AuthToken) - - // No tokens should have been created - assert.Equal(t, 0, tokenStore.GetTokenCount()) } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + ctx := context.Background() + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } + s.SetProxyController(newTestProxyController()) // Register proxies in different clusters (SendServiceUpdate broadcasts to all) ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") ch2 := registerFakeProxy(s, "proxy-b", "cluster-b") - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-1", AccountId: "account-1", Domain: "test.example.com", } + update := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{mapping}, + } + s.SendServiceUpdate(update) - msg1 := drainChannel(ch1) - msg2 := drainChannel(ch2) + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) require.NotNil(t, msg1) require.NotNil(t, msg2) @@ -178,10 +266,14 @@ func generateState(s *ProxyServiceServer, redirectURL string) string { } func TestOAuthState_NeverTheSame(t *testing.T) { + ctx := context.Background() + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } redirectURL := "https://app.example.com/callback" @@ -202,31 +294,370 @@ func TestOAuthState_NeverTheSame(t *testing.T) { } func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { + ctx := context.Background() + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } // Old format had only 2 parts: base64(url)|hmac - s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + err := s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) + require.NoError(t, err) - _, _, err := s.ValidateState("base64url|hmac") + _, _, err = s.ValidateState("base64url|hmac") require.Error(t, err) assert.Contains(t, err.Error(), "invalid state format") } func TestValidateState_RejectsInvalidHMAC(t *testing.T) { + ctx := context.Background() + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } // Store with tampered HMAC - s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + err := s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) + require.NoError(t, err) - _, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac") + _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac") require.Error(t, err) assert.Contains(t, err.Error(), "invalid state signature") } + +func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + // Modern proxy reports capabilities. + chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Legacy proxy never reported capabilities (nil). + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) + + ctx := context.Background() + + // TLS passthrough with custom port: all proxies receive it (SNI routing). + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) + + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)") + + // TCP mapping with custom port: only modern proxy receives it. + tcpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tcp", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tcp", + ListenPort: 5432, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping") + assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping") + + // HTTP mapping (no listen port): both receive it. + httpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-http", + AccountId: "account-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:80"}}, + } + + s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) + + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping") + + // Proxy that reports SupportsCustomPorts=false still receives custom-port + // mappings because it understands the protocol (it's new enough). + chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping") +} + +func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + // Legacy proxy (no capabilities) still receives TLS since it uses SNI. + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) + + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) + + msg := drainMapping(chLegacy) + assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)") +} + +// TestServiceModifyNotifications exercises every possible modification +// scenario for an existing service, verifying the correct update types +// reach the correct clusters. +func TestServiceModifyNotifications(t *testing.T) { + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + + newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + chs := map[string]chan *proto.GetMappingUpdateResponse{ + "cluster-a": registerFakeProxyWithCaps(s, "proxy-a", "cluster-a", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + "cluster-b": registerFakeProxyWithCaps(s, "proxy-b", "cluster-b", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + } + return s, chs + } + + httpMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:8080"}}, + } + } + + tlsOnlyMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.1:443"}}, + } + } + + ctx := context.Background() + + t.Run("targets changed sends MODIFIED to same cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg, "cluster-a should receive update") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.NotEmpty(t, msg.AuthToken, "MODIFIED should include token") + assert.True(t, drainEmpty(chs["cluster-b"]), "cluster-b should not receive update") + }) + + t.Run("auth config changed sends MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.Auth = &proto.Authentication{Password: true, Pin: true} + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.True(t, msg.Auth.Password) + assert.True(t, msg.Auth.Pin) + }) + + t.Run("HTTP to TLS transition sends MODIFIED with TLS config", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Equal(t, "tls", msg.Mode, "mode should be tls") + assert.Equal(t, int32(8443), msg.ListenPort) + assert.Len(t, msg.Path, 1, "should have one path entry with target address") + assert.Equal(t, "10.0.0.1:443", msg.Path[0].Target) + }) + + t.Run("TLS to HTTP transition sends MODIFIED without TLS", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Empty(t, msg.Mode, "mode should be empty for HTTP") + assert.True(t, len(msg.Path) > 0) + }) + + t.Run("TLS port changed sends MODIFIED with new port", func(t *testing.T) { + s, chs := newServer() + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 9443 + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, int32(9443), msg.ListenPort) + }) + + t.Run("disable sends REMOVED to cluster", func(t *testing.T) { + s, chs := newServer() + // Manager sends Delete when service is disabled + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.Type) + assert.Empty(t, msg.AuthToken, "DELETE should not have token") + }) + + t.Run("enable sends CREATED to cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msg.Type) + assert.NotEmpty(t, msg.AuthToken) + }) + + t.Run("domain change with cluster change sends DELETE to old CREATE to new", func(t *testing.T) { + s, chs := newServer() + // This is the pattern the manager produces: + // 1. DELETE on old cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + // 2. CREATE on new cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-b") + + msgA := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgA, "old cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgA.Type) + + msgB := drainMapping(chs["cluster-b"]) + require.NotNil(t, msgB, "new cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgB.Type) + assert.NotEmpty(t, msgB.AuthToken) + }) + + t.Run("domain change same cluster sends DELETE then CREATE", func(t *testing.T) { + s, chs := newServer() + // Domain changes within same cluster: manager sends DELETE (old domain) + CREATE (new domain). + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msgDel := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgDel, "same cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgDel.Type) + + msgCreate := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgCreate, "same cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgCreate.Type) + assert.NotEmpty(t, msgCreate.AuthToken) + }) + + t.Run("TLS passthrough sent to all proxies", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) + + // TLS passthrough works on all proxies regardless of custom port support + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) + + msgModern := drainMapping(chModern) + require.NotNil(t, msgModern, "modern proxy receives TLS update") + assert.Equal(t, "tls", msgModern.Mode) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should also receive TLS passthrough") + }) + + t.Run("TLS on default port NOT filtered for legacy proxy", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chLegacy := registerFakeProxy(s, "legacy", cluster) + + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 0 // default port + s.SendServiceUpdateToCluster(ctx, mapping, cluster) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should receive TLS on default port") + }) + + t.Run("passthrough and rewrite flags propagated", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.PassHostHeader = true + mapping.RewriteRedirects = true + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.True(t, msg.PassHostHeader) + assert.True(t, msg.RewriteRedirects) + }) + + t.Run("multiple paths propagated in MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, + Id: "svc-multi", + AccountId: "acct-1", + Domain: "multi.example.com", + Path: []*proto.PathMapping{ + {Path: "/", Target: "http://10.0.0.1:8080"}, + {Path: "/api", Target: "http://10.0.0.2:9090"}, + {Path: "/ws", Target: "http://10.0.0.3:3000"}, + }, + } + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + require.Len(t, msg.Path, 3, "all paths should be present") + assert.Equal(t, "/", msg.Path[0].Path) + assert.Equal(t, "/api", msg.Path[1].Path) + assert.Equal(t, "/ws", msg.Path[2].Path) + }) +} diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 0167aca07..6e8358f02 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/shared/management/client/common" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/job" @@ -80,6 +81,9 @@ type Server struct { syncSem atomic.Int32 syncLimEnabled bool syncLim int32 + + reverseProxyManager rpservice.Manager + reverseProxyMu sync.RWMutex } // NewServer creates a new Management server @@ -326,13 +330,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) - } - unlock() unlock = nil + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) + } log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart)) s.syncSem.Add(-1) @@ -739,13 +742,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) - defer func() { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) - } - log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart)) - }() - if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) @@ -795,6 +791,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto return nil, status.Errorf(codes.Internal, "failed logging in peer") } + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) + } + log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart)) + return &proto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: encryptedResp, diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index f76d3ada0..d1d7fc8b7 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -13,7 +13,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -34,11 +35,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir()) require.NoError(t, err) - proxyManager := &testValidateSessionProxyManager{store: testStore} + serviceManager := &testValidateSessionServiceManager{store: testStore} usersManager := &testValidateSessionUsersManager{store: testStore} + proxyManager := &testValidateSessionProxyManager{} - proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager) - proxyService.SetProxyManager(proxyManager) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) + + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) @@ -54,7 +59,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) pubKey, privKey := generateSessionKeyPair(t) - testProxy := &reverseproxy.Service{ + testProxy := &service.Service{ ID: "testProxyId", AccountID: "testAccountId", Name: "Test Proxy", @@ -62,15 +67,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) Enabled: true, SessionPrivateKey: privKey, SessionPublicKey: pubKey, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, }, }, } require.NoError(t, testStore.CreateService(ctx, testProxy)) - restrictedProxy := &reverseproxy.Service{ + restrictedProxy := &service.Service{ ID: "restrictedProxyId", AccountID: "testAccountId", Name: "Restricted Proxy", @@ -78,8 +83,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) Enabled: true, SessionPrivateKey: privKey, SessionPublicKey: pubKey, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"allowedGroupId"}, }, @@ -196,7 +201,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) { require.NoError(t, err) assert.False(t, resp.Valid, "Unknown proxy should be denied") - assert.Equal(t, "proxy_not_found", resp.DeniedReason) + assert.Equal(t, "service_not_found", resp.DeniedReason) } func TestValidateSession_InvalidToken(t *testing.T) { @@ -239,62 +244,122 @@ func TestValidateSession_MissingToken(t *testing.T) { assert.Contains(t, resp.DeniedReason, "missing") } -type testValidateSessionProxyManager struct { +type testValidateSessionServiceManager struct { store store.Store } -func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { +func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { +func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { +func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error { return nil } -func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error { +func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { return nil } -func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error { + return nil +} + +func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return m.store.GetServices(ctx, store.LockingStrengthNone) } -func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) } -func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } -func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { +func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { return "", nil } +func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + return nil, nil +} + +func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} + +func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { + return nil, nil +} + +type testValidateSessionProxyManager struct{} + +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { + return nil +} + +func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + type testValidateSessionUsersManager struct { store store.Store } diff --git a/management/server/account.go b/management/server/account.go index 1e35d4ad1..7d53cef03 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -15,7 +15,7 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/shared/auth" @@ -83,9 +83,9 @@ type DefaultAccountManager struct { requestBuffer *AccountRequestBuffer - proxyController port_forwarding.Controller - settingsManager settings.Manager - reverseProxyManager reverseproxy.Manager + proxyController port_forwarding.Controller + settingsManager settings.Manager + serviceManager service.Manager // config contains the management server configuration config *nbconfig.Config @@ -115,8 +115,8 @@ type DefaultAccountManager struct { var _ account.Manager = (*DefaultAccountManager)(nil) -func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { - am.reverseProxyManager = serviceManager +func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) { + am.serviceManager = serviceManager } func isUniqueConstraintError(err error) bool { @@ -181,7 +181,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] return modified, newUserAutoGroups, newGroupsToCreate, nil } -// BuildManager creates a new DefaultAccountManager with a provided Store +// BuildManager creates a new DefaultAccountManager with all dependencies. func BuildManager( ctx context.Context, config *nbconfig.Config, @@ -199,6 +199,7 @@ func BuildManager( settingsManager settings.Manager, permissionsManager permissions.Manager, disableDefaultPolicy bool, + sharedCacheStore cacheStore.StoreInterface, ) (*DefaultAccountManager, error) { start := time.Now() defer func() { @@ -247,16 +248,12 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) - if err != nil { - return nil, fmt.Errorf("getting cache store: %s", err) - } - am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) - am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) + am.externalCacheManager = nbcache.NewUserDataCache(sharedCacheStore) + am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, sharedCacheStore) if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx, cacheStore) + err := am.warmupIDPCache(ctx, sharedCacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? @@ -335,7 +332,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.DNSDomain != newSettings.DNSDomain || - oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { + oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion || + oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { updateAccountPeers = true } @@ -376,6 +374,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleAutoUpdateAlwaysSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err } @@ -394,7 +394,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) } if reloadReverseProxy { - if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { + if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) } } @@ -492,6 +492,31 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con } } +func (am *DefaultAccountManager) handleAutoUpdateAlwaysSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { + if newSettings.AutoUpdateAlways { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysDisabled, nil) + } + } +} + +func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + oldEnabled := oldSettings.PeerExposeEnabled + newEnabled := newSettings.PeerExposeEnabled + + if oldEnabled == newEnabled { + return + } + + event := activity.AccountPeerExposeEnabled + if !newEnabled { + event = activity.AccountPeerExposeDisabled + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { @@ -1358,9 +1383,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. - userAuth.Domain = am.singleAccountModeDomain - userAuth.DomainCategory = types.PrivateCategory - log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + err := am.updateUserAuthWithSingleMode(ctx, &userAuth) + if err != nil { + return "", "", err + } } accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth) @@ -1393,6 +1419,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u return accountID, user.Id, nil } +// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account +func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error { + userAuth.DomainCategory = types.PrivateCategory + userAuth.Domain = am.singleAccountModeDomain + + accountID, err := am.Store.GetAnyAccountID(ctx) + if err != nil { + if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound { + return err + } + log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode") + return nil + } + + if accountID == "" { + log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode") + return nil + } + + domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + userAuth.Domain = domain + + log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + return nil +} + // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 207ab71d6..b4516d512 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -1,12 +1,14 @@ package account +//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + import ( "context" "net" "net/netip" "time" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/shared/auth" nbdns "github.com/netbirdio/netbird/dns" @@ -61,11 +63,11 @@ type Manager interface { GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error - UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) - AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) @@ -73,7 +75,7 @@ type Manager interface { GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error @@ -140,5 +142,5 @@ type Manager interface { CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) - SetServiceManager(serviceManager reverseproxy.Manager) + SetServiceManager(serviceManager service.Manager) } diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go new file mode 100644 index 000000000..36e5fe39f --- /dev/null +++ b/management/server/account/manager_mock.go @@ -0,0 +1,1738 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./manager.go + +// Package account is a generated GoMock package. +package account + +import ( + context "context" + net "net" + netip "net/netip" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + dns "github.com/netbirdio/netbird/dns" + service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + activity "github.com/netbirdio/netbird/management/server/activity" + idp "github.com/netbirdio/netbird/management/server/idp" + peer "github.com/netbirdio/netbird/management/server/peer" + posture "github.com/netbirdio/netbird/management/server/posture" + store "github.com/netbirdio/netbird/management/server/store" + types "github.com/netbirdio/netbird/management/server/types" + users "github.com/netbirdio/netbird/management/server/users" + route "github.com/netbirdio/netbird/route" + auth "github.com/netbirdio/netbird/shared/auth" + domain "github.com/netbirdio/netbird/shared/management/domain" +) + +// MockManager is a mock of Manager interface. +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder +} + +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// AcceptUserInvite mocks base method. +func (m *MockManager) AcceptUserInvite(ctx context.Context, token, password string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptUserInvite", ctx, token, password) + ret0, _ := ret[0].(error) + return ret0 +} + +// AcceptUserInvite indicates an expected call of AcceptUserInvite. +func (mr *MockManagerMockRecorder) AcceptUserInvite(ctx, token, password interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUserInvite", reflect.TypeOf((*MockManager)(nil).AcceptUserInvite), ctx, token, password) +} + +// AccountExists mocks base method. +func (m *MockManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccountExists", ctx, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AccountExists indicates an expected call of AccountExists. +func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccountExists", reflect.TypeOf((*MockManager)(nil).AccountExists), ctx, accountID) +} + +// AddPeer mocks base method. +func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// AddPeer indicates an expected call of AddPeer. +func (mr *MockManagerMockRecorder) AddPeer(ctx, accountID, setupKey, userID, p, temporary interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPeer", reflect.TypeOf((*MockManager)(nil).AddPeer), ctx, accountID, setupKey, userID, p, temporary) +} + +// ApproveUser mocks base method. +func (m *MockManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApproveUser", ctx, accountID, initiatorUserID, targetUserID) + ret0, _ := ret[0].(*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApproveUser indicates an expected call of ApproveUser. +func (mr *MockManagerMockRecorder) ApproveUser(ctx, accountID, initiatorUserID, targetUserID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveUser", reflect.TypeOf((*MockManager)(nil).ApproveUser), ctx, accountID, initiatorUserID, targetUserID) +} + +// BufferUpdateAccountPeers mocks base method. +func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) +} + +// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. +func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID) +} + +// BuildUserInfosForAccount mocks base method. +func (m *MockManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BuildUserInfosForAccount", ctx, accountID, initiatorUserID, accountUsers) + ret0, _ := ret[0].(map[string]*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BuildUserInfosForAccount indicates an expected call of BuildUserInfosForAccount. +func (mr *MockManagerMockRecorder) BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildUserInfosForAccount", reflect.TypeOf((*MockManager)(nil).BuildUserInfosForAccount), ctx, accountID, initiatorUserID, accountUsers) +} + +// CreateGroup mocks base method. +func (m *MockManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateGroup", ctx, accountID, userID, group) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateGroup indicates an expected call of CreateGroup. +func (mr *MockManagerMockRecorder) CreateGroup(ctx, accountID, userID, group interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroup", reflect.TypeOf((*MockManager)(nil).CreateGroup), ctx, accountID, userID, group) +} + +// CreateGroups mocks base method. +func (m *MockManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateGroups", ctx, accountID, userID, newGroups) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateGroups indicates an expected call of CreateGroups. +func (mr *MockManagerMockRecorder) CreateGroups(ctx, accountID, userID, newGroups interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroups", reflect.TypeOf((*MockManager)(nil).CreateGroups), ctx, accountID, userID, newGroups) +} + +// CreateIdentityProvider mocks base method. +func (m *MockManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateIdentityProvider", ctx, accountID, userID, idp) + ret0, _ := ret[0].(*types.IdentityProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateIdentityProvider indicates an expected call of CreateIdentityProvider. +func (mr *MockManagerMockRecorder) CreateIdentityProvider(ctx, accountID, userID, idp interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateIdentityProvider", reflect.TypeOf((*MockManager)(nil).CreateIdentityProvider), ctx, accountID, userID, idp) +} + +// CreateNameServerGroup mocks base method. +func (m *MockManager) CreateNameServerGroup(ctx context.Context, accountID, name, description string, nameServerList []dns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*dns.NameServerGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateNameServerGroup", ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) + ret0, _ := ret[0].(*dns.NameServerGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateNameServerGroup indicates an expected call of CreateNameServerGroup. +func (mr *MockManagerMockRecorder) CreateNameServerGroup(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNameServerGroup", reflect.TypeOf((*MockManager)(nil).CreateNameServerGroup), ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) +} + +// CreatePAT mocks base method. +func (m *MockManager) CreatePAT(ctx context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePAT", ctx, accountID, initiatorUserID, targetUserID, tokenName, expiresIn) + ret0, _ := ret[0].(*types.PersonalAccessTokenGenerated) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreatePAT indicates an expected call of CreatePAT. +func (mr *MockManagerMockRecorder) CreatePAT(ctx, accountID, initiatorUserID, targetUserID, tokenName, expiresIn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePAT", reflect.TypeOf((*MockManager)(nil).CreatePAT), ctx, accountID, initiatorUserID, targetUserID, tokenName, expiresIn) +} + +// CreatePeerJob mocks base method. +func (m *MockManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePeerJob", ctx, accountID, peerID, userID, job) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreatePeerJob indicates an expected call of CreatePeerJob. +func (mr *MockManagerMockRecorder) CreatePeerJob(ctx, accountID, peerID, userID, job interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePeerJob", reflect.TypeOf((*MockManager)(nil).CreatePeerJob), ctx, accountID, peerID, userID, job) +} + +// CreateRoute mocks base method. +func (m *MockManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute, skipAutoApply bool) (*route.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateRoute", ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupIDs, enabled, userID, keepRoute, skipAutoApply) + ret0, _ := ret[0].(*route.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateRoute indicates an expected call of CreateRoute. +func (mr *MockManagerMockRecorder) CreateRoute(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupIDs, enabled, userID, keepRoute, skipAutoApply interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRoute", reflect.TypeOf((*MockManager)(nil).CreateRoute), ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupIDs, enabled, userID, keepRoute, skipAutoApply) +} + +// CreateSetupKey mocks base method. +func (m *MockManager) CreateSetupKey(ctx context.Context, accountID, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral, allowExtraDNSLabels bool) (*types.SetupKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSetupKey", ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral, allowExtraDNSLabels) + ret0, _ := ret[0].(*types.SetupKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSetupKey indicates an expected call of CreateSetupKey. +func (mr *MockManagerMockRecorder) CreateSetupKey(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral, allowExtraDNSLabels interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSetupKey", reflect.TypeOf((*MockManager)(nil).CreateSetupKey), ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral, allowExtraDNSLabels) +} + +// CreateUser mocks base method. +func (m *MockManager) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUser", ctx, accountID, initiatorUserID, key) + ret0, _ := ret[0].(*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUser indicates an expected call of CreateUser. +func (mr *MockManagerMockRecorder) CreateUser(ctx, accountID, initiatorUserID, key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockManager)(nil).CreateUser), ctx, accountID, initiatorUserID, key) +} + +// CreateUserInvite mocks base method. +func (m *MockManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUserInvite", ctx, accountID, initiatorUserID, invite, expiresIn) + ret0, _ := ret[0].(*types.UserInvite) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUserInvite indicates an expected call of CreateUserInvite. +func (mr *MockManagerMockRecorder) CreateUserInvite(ctx, accountID, initiatorUserID, invite, expiresIn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUserInvite", reflect.TypeOf((*MockManager)(nil).CreateUserInvite), ctx, accountID, initiatorUserID, invite, expiresIn) +} + +// DeleteAccount mocks base method. +func (m *MockManager) DeleteAccount(ctx context.Context, accountID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccount", ctx, accountID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccount indicates an expected call of DeleteAccount. +func (mr *MockManagerMockRecorder) DeleteAccount(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccount", reflect.TypeOf((*MockManager)(nil).DeleteAccount), ctx, accountID, userID) +} + +// DeleteGroup mocks base method. +func (m *MockManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteGroup", ctx, accountId, userId, groupID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteGroup indicates an expected call of DeleteGroup. +func (mr *MockManagerMockRecorder) DeleteGroup(ctx, accountId, userId, groupID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroup", reflect.TypeOf((*MockManager)(nil).DeleteGroup), ctx, accountId, userId, groupID) +} + +// DeleteGroups mocks base method. +func (m *MockManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteGroups", ctx, accountId, userId, groupIDs) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteGroups indicates an expected call of DeleteGroups. +func (mr *MockManagerMockRecorder) DeleteGroups(ctx, accountId, userId, groupIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroups", reflect.TypeOf((*MockManager)(nil).DeleteGroups), ctx, accountId, userId, groupIDs) +} + +// DeleteIdentityProvider mocks base method. +func (m *MockManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteIdentityProvider", ctx, accountID, idpID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteIdentityProvider indicates an expected call of DeleteIdentityProvider. +func (mr *MockManagerMockRecorder) DeleteIdentityProvider(ctx, accountID, idpID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteIdentityProvider", reflect.TypeOf((*MockManager)(nil).DeleteIdentityProvider), ctx, accountID, idpID, userID) +} + +// DeleteNameServerGroup mocks base method. +func (m *MockManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteNameServerGroup", ctx, accountID, nsGroupID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteNameServerGroup indicates an expected call of DeleteNameServerGroup. +func (mr *MockManagerMockRecorder) DeleteNameServerGroup(ctx, accountID, nsGroupID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteNameServerGroup", reflect.TypeOf((*MockManager)(nil).DeleteNameServerGroup), ctx, accountID, nsGroupID, userID) +} + +// DeletePAT mocks base method. +func (m *MockManager) DeletePAT(ctx context.Context, accountID, initiatorUserID, targetUserID, tokenID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePAT", ctx, accountID, initiatorUserID, targetUserID, tokenID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePAT indicates an expected call of DeletePAT. +func (mr *MockManagerMockRecorder) DeletePAT(ctx, accountID, initiatorUserID, targetUserID, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePAT", reflect.TypeOf((*MockManager)(nil).DeletePAT), ctx, accountID, initiatorUserID, targetUserID, tokenID) +} + +// DeletePeer mocks base method. +func (m *MockManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeer", ctx, accountID, peerID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeer indicates an expected call of DeletePeer. +func (mr *MockManagerMockRecorder) DeletePeer(ctx, accountID, peerID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockManager)(nil).DeletePeer), ctx, accountID, peerID, userID) +} + +// DeletePolicy mocks base method. +func (m *MockManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePolicy", ctx, accountID, policyID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePolicy indicates an expected call of DeletePolicy. +func (mr *MockManagerMockRecorder) DeletePolicy(ctx, accountID, policyID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePolicy", reflect.TypeOf((*MockManager)(nil).DeletePolicy), ctx, accountID, policyID, userID) +} + +// DeletePostureChecks mocks base method. +func (m *MockManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePostureChecks", ctx, accountID, postureChecksID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePostureChecks indicates an expected call of DeletePostureChecks. +func (mr *MockManagerMockRecorder) DeletePostureChecks(ctx, accountID, postureChecksID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockManager)(nil).DeletePostureChecks), ctx, accountID, postureChecksID, userID) +} + +// DeleteRegularUsers mocks base method. +func (m *MockManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRegularUsers", ctx, accountID, initiatorUserID, targetUserIDs, userInfos) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteRegularUsers indicates an expected call of DeleteRegularUsers. +func (mr *MockManagerMockRecorder) DeleteRegularUsers(ctx, accountID, initiatorUserID, targetUserIDs, userInfos interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRegularUsers", reflect.TypeOf((*MockManager)(nil).DeleteRegularUsers), ctx, accountID, initiatorUserID, targetUserIDs, userInfos) +} + +// DeleteRoute mocks base method. +func (m *MockManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRoute", ctx, accountID, routeID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteRoute indicates an expected call of DeleteRoute. +func (mr *MockManagerMockRecorder) DeleteRoute(ctx, accountID, routeID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRoute", reflect.TypeOf((*MockManager)(nil).DeleteRoute), ctx, accountID, routeID, userID) +} + +// DeleteSetupKey mocks base method. +func (m *MockManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSetupKey", ctx, accountID, userID, keyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSetupKey indicates an expected call of DeleteSetupKey. +func (mr *MockManagerMockRecorder) DeleteSetupKey(ctx, accountID, userID, keyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockManager)(nil).DeleteSetupKey), ctx, accountID, userID, keyID) +} + +// DeleteUser mocks base method. +func (m *MockManager) DeleteUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUser", ctx, accountID, initiatorUserID, targetUserID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUser indicates an expected call of DeleteUser. +func (mr *MockManagerMockRecorder) DeleteUser(ctx, accountID, initiatorUserID, targetUserID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockManager)(nil).DeleteUser), ctx, accountID, initiatorUserID, targetUserID) +} + +// DeleteUserInvite mocks base method. +func (m *MockManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserInvite", ctx, accountID, initiatorUserID, inviteID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserInvite indicates an expected call of DeleteUserInvite. +func (mr *MockManagerMockRecorder) DeleteUserInvite(ctx, accountID, initiatorUserID, inviteID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserInvite", reflect.TypeOf((*MockManager)(nil).DeleteUserInvite), ctx, accountID, initiatorUserID, inviteID) +} + +// FindExistingPostureCheck mocks base method. +func (m *MockManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindExistingPostureCheck", accountID, checks) + ret0, _ := ret[0].(*posture.Checks) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindExistingPostureCheck indicates an expected call of FindExistingPostureCheck. +func (mr *MockManagerMockRecorder) FindExistingPostureCheck(accountID, checks interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindExistingPostureCheck", reflect.TypeOf((*MockManager)(nil).FindExistingPostureCheck), accountID, checks) +} + +// GetAccount mocks base method. +func (m *MockManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccount", ctx, accountID) + ret0, _ := ret[0].(*types.Account) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccount indicates an expected call of GetAccount. +func (mr *MockManagerMockRecorder) GetAccount(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccount", reflect.TypeOf((*MockManager)(nil).GetAccount), ctx, accountID) +} + +// GetAccountByID mocks base method. +func (m *MockManager) GetAccountByID(ctx context.Context, accountID, userID string) (*types.Account, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountByID", ctx, accountID, userID) + ret0, _ := ret[0].(*types.Account) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountByID indicates an expected call of GetAccountByID. +func (mr *MockManagerMockRecorder) GetAccountByID(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountByID", reflect.TypeOf((*MockManager)(nil).GetAccountByID), ctx, accountID, userID) +} + +// GetAccountIDByUserID mocks base method. +func (m *MockManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountIDByUserID", ctx, userAuth) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountIDByUserID indicates an expected call of GetAccountIDByUserID. +func (mr *MockManagerMockRecorder) GetAccountIDByUserID(ctx, userAuth interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountIDByUserID", reflect.TypeOf((*MockManager)(nil).GetAccountIDByUserID), ctx, userAuth) +} + +// GetAccountIDForPeerKey mocks base method. +func (m *MockManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountIDForPeerKey", ctx, peerKey) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountIDForPeerKey indicates an expected call of GetAccountIDForPeerKey. +func (mr *MockManagerMockRecorder) GetAccountIDForPeerKey(ctx, peerKey interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountIDForPeerKey", reflect.TypeOf((*MockManager)(nil).GetAccountIDForPeerKey), ctx, peerKey) +} + +// GetAccountIDFromUserAuth mocks base method. +func (m *MockManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountIDFromUserAuth", ctx, userAuth) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetAccountIDFromUserAuth indicates an expected call of GetAccountIDFromUserAuth. +func (mr *MockManagerMockRecorder) GetAccountIDFromUserAuth(ctx, userAuth interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountIDFromUserAuth", reflect.TypeOf((*MockManager)(nil).GetAccountIDFromUserAuth), ctx, userAuth) +} + +// GetAccountMeta mocks base method. +func (m *MockManager) GetAccountMeta(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountMeta", ctx, accountID, userID) + ret0, _ := ret[0].(*types.AccountMeta) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountMeta indicates an expected call of GetAccountMeta. +func (mr *MockManagerMockRecorder) GetAccountMeta(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountMeta", reflect.TypeOf((*MockManager)(nil).GetAccountMeta), ctx, accountID, userID) +} + +// GetAccountOnboarding mocks base method. +func (m *MockManager) GetAccountOnboarding(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountOnboarding", ctx, accountID, userID) + ret0, _ := ret[0].(*types.AccountOnboarding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountOnboarding indicates an expected call of GetAccountOnboarding. +func (mr *MockManagerMockRecorder) GetAccountOnboarding(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountOnboarding", reflect.TypeOf((*MockManager)(nil).GetAccountOnboarding), ctx, accountID, userID) +} + +// GetAccountSettings mocks base method. +func (m *MockManager) GetAccountSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountSettings", ctx, accountID, userID) + ret0, _ := ret[0].(*types.Settings) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountSettings indicates an expected call of GetAccountSettings. +func (mr *MockManagerMockRecorder) GetAccountSettings(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountSettings", reflect.TypeOf((*MockManager)(nil).GetAccountSettings), ctx, accountID, userID) +} + +// GetAllGroups mocks base method. +func (m *MockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllGroups", ctx, accountID, userID) + ret0, _ := ret[0].([]*types.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllGroups indicates an expected call of GetAllGroups. +func (mr *MockManagerMockRecorder) GetAllGroups(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllGroups", reflect.TypeOf((*MockManager)(nil).GetAllGroups), ctx, accountID, userID) +} + +// GetAllPATs mocks base method. +func (m *MockManager) GetAllPATs(ctx context.Context, accountID, initiatorUserID, targetUserID string) ([]*types.PersonalAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllPATs", ctx, accountID, initiatorUserID, targetUserID) + ret0, _ := ret[0].([]*types.PersonalAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllPATs indicates an expected call of GetAllPATs. +func (mr *MockManagerMockRecorder) GetAllPATs(ctx, accountID, initiatorUserID, targetUserID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPATs", reflect.TypeOf((*MockManager)(nil).GetAllPATs), ctx, accountID, initiatorUserID, targetUserID) +} + +// GetAllPeerJobs mocks base method. +func (m *MockManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllPeerJobs", ctx, accountID, userID, peerID) + ret0, _ := ret[0].([]*types.Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllPeerJobs indicates an expected call of GetAllPeerJobs. +func (mr *MockManagerMockRecorder) GetAllPeerJobs(ctx, accountID, userID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPeerJobs", reflect.TypeOf((*MockManager)(nil).GetAllPeerJobs), ctx, accountID, userID, peerID) +} + +// GetCurrentUserInfo mocks base method. +func (m *MockManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCurrentUserInfo", ctx, userAuth) + ret0, _ := ret[0].(*users.UserInfoWithPermissions) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCurrentUserInfo indicates an expected call of GetCurrentUserInfo. +func (mr *MockManagerMockRecorder) GetCurrentUserInfo(ctx, userAuth interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentUserInfo", reflect.TypeOf((*MockManager)(nil).GetCurrentUserInfo), ctx, userAuth) +} + +// GetDNSSettings mocks base method. +func (m *MockManager) GetDNSSettings(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDNSSettings", ctx, accountID, userID) + ret0, _ := ret[0].(*types.DNSSettings) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDNSSettings indicates an expected call of GetDNSSettings. +func (mr *MockManagerMockRecorder) GetDNSSettings(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSSettings", reflect.TypeOf((*MockManager)(nil).GetDNSSettings), ctx, accountID, userID) +} + +// GetEvents mocks base method. +func (m *MockManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEvents", ctx, accountID, userID) + ret0, _ := ret[0].([]*activity.Event) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEvents indicates an expected call of GetEvents. +func (mr *MockManagerMockRecorder) GetEvents(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvents", reflect.TypeOf((*MockManager)(nil).GetEvents), ctx, accountID, userID) +} + +// GetExternalCacheManager mocks base method. +func (m *MockManager) GetExternalCacheManager() ExternalCacheManager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExternalCacheManager") + ret0, _ := ret[0].(ExternalCacheManager) + return ret0 +} + +// GetExternalCacheManager indicates an expected call of GetExternalCacheManager. +func (mr *MockManagerMockRecorder) GetExternalCacheManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExternalCacheManager", reflect.TypeOf((*MockManager)(nil).GetExternalCacheManager)) +} + +// GetGroup mocks base method. +func (m *MockManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroup", ctx, accountId, groupID, userID) + ret0, _ := ret[0].(*types.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroup indicates an expected call of GetGroup. +func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroup", reflect.TypeOf((*MockManager)(nil).GetGroup), ctx, accountId, groupID, userID) +} + +// GetGroupByName mocks base method. +func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID) + ret0, _ := ret[0].(*types.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroupByName indicates an expected call of GetGroupByName. +func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID) +} + +// GetIdentityProvider mocks base method. +func (m *MockManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIdentityProvider", ctx, accountID, idpID, userID) + ret0, _ := ret[0].(*types.IdentityProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetIdentityProvider indicates an expected call of GetIdentityProvider. +func (mr *MockManagerMockRecorder) GetIdentityProvider(ctx, accountID, idpID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIdentityProvider", reflect.TypeOf((*MockManager)(nil).GetIdentityProvider), ctx, accountID, idpID, userID) +} + +// GetIdentityProviders mocks base method. +func (m *MockManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIdentityProviders", ctx, accountID, userID) + ret0, _ := ret[0].([]*types.IdentityProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetIdentityProviders indicates an expected call of GetIdentityProviders. +func (mr *MockManagerMockRecorder) GetIdentityProviders(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIdentityProviders", reflect.TypeOf((*MockManager)(nil).GetIdentityProviders), ctx, accountID, userID) +} + +// GetIdpManager mocks base method. +func (m *MockManager) GetIdpManager() idp.Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIdpManager") + ret0, _ := ret[0].(idp.Manager) + return ret0 +} + +// GetIdpManager indicates an expected call of GetIdpManager. +func (mr *MockManagerMockRecorder) GetIdpManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIdpManager", reflect.TypeOf((*MockManager)(nil).GetIdpManager)) +} + +// GetNameServerGroup mocks base method. +func (m *MockManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*dns.NameServerGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNameServerGroup", ctx, accountID, userID, nsGroupID) + ret0, _ := ret[0].(*dns.NameServerGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNameServerGroup indicates an expected call of GetNameServerGroup. +func (mr *MockManagerMockRecorder) GetNameServerGroup(ctx, accountID, userID, nsGroupID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNameServerGroup", reflect.TypeOf((*MockManager)(nil).GetNameServerGroup), ctx, accountID, userID, nsGroupID) +} + +// GetNetworkMap mocks base method. +func (m *MockManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID) + ret0, _ := ret[0].(*types.NetworkMap) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNetworkMap indicates an expected call of GetNetworkMap. +func (mr *MockManagerMockRecorder) GetNetworkMap(ctx, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockManager)(nil).GetNetworkMap), ctx, peerID) +} + +// GetOrCreateAccountByPrivateDomain mocks base method. +func (m *MockManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOrCreateAccountByPrivateDomain", ctx, initiatorId, domain) + ret0, _ := ret[0].(*types.Account) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetOrCreateAccountByPrivateDomain indicates an expected call of GetOrCreateAccountByPrivateDomain. +func (mr *MockManagerMockRecorder) GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrCreateAccountByPrivateDomain", reflect.TypeOf((*MockManager)(nil).GetOrCreateAccountByPrivateDomain), ctx, initiatorId, domain) +} + +// GetOrCreateAccountByUser mocks base method. +func (m *MockManager) GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOrCreateAccountByUser", ctx, userAuth) + ret0, _ := ret[0].(*types.Account) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrCreateAccountByUser indicates an expected call of GetOrCreateAccountByUser. +func (mr *MockManagerMockRecorder) GetOrCreateAccountByUser(ctx, userAuth interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrCreateAccountByUser", reflect.TypeOf((*MockManager)(nil).GetOrCreateAccountByUser), ctx, userAuth) +} + +// GetOwnerInfo mocks base method. +func (m *MockManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOwnerInfo", ctx, accountId) + ret0, _ := ret[0].(*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOwnerInfo indicates an expected call of GetOwnerInfo. +func (mr *MockManagerMockRecorder) GetOwnerInfo(ctx, accountId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOwnerInfo", reflect.TypeOf((*MockManager)(nil).GetOwnerInfo), ctx, accountId) +} + +// GetPAT mocks base method. +func (m *MockManager) GetPAT(ctx context.Context, accountID, initiatorUserID, targetUserID, tokenID string) (*types.PersonalAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPAT", ctx, accountID, initiatorUserID, targetUserID, tokenID) + ret0, _ := ret[0].(*types.PersonalAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPAT indicates an expected call of GetPAT. +func (mr *MockManagerMockRecorder) GetPAT(ctx, accountID, initiatorUserID, targetUserID, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPAT", reflect.TypeOf((*MockManager)(nil).GetPAT), ctx, accountID, initiatorUserID, targetUserID, tokenID) +} + +// GetPeer mocks base method. +func (m *MockManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeer", ctx, accountID, peerID, userID) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeer indicates an expected call of GetPeer. +func (mr *MockManagerMockRecorder) GetPeer(ctx, accountID, peerID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeer", reflect.TypeOf((*MockManager)(nil).GetPeer), ctx, accountID, peerID, userID) +} + +// GetPeerGroups mocks base method. +func (m *MockManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerGroups", ctx, accountID, peerID) + ret0, _ := ret[0].([]*types.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeerGroups indicates an expected call of GetPeerGroups. +func (mr *MockManagerMockRecorder) GetPeerGroups(ctx, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerGroups", reflect.TypeOf((*MockManager)(nil).GetPeerGroups), ctx, accountID, peerID) +} + +// GetPeerJobByID mocks base method. +func (m *MockManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerJobByID", ctx, accountID, userID, peerID, jobID) + ret0, _ := ret[0].(*types.Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeerJobByID indicates an expected call of GetPeerJobByID. +func (mr *MockManagerMockRecorder) GetPeerJobByID(ctx, accountID, userID, peerID, jobID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerJobByID", reflect.TypeOf((*MockManager)(nil).GetPeerJobByID), ctx, accountID, userID, peerID, jobID) +} + +// GetPeerNetwork mocks base method. +func (m *MockManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerNetwork", ctx, peerID) + ret0, _ := ret[0].(*types.Network) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeerNetwork indicates an expected call of GetPeerNetwork. +func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerNetwork", reflect.TypeOf((*MockManager)(nil).GetPeerNetwork), ctx, peerID) +} + +// GetPeers mocks base method. +func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter) + ret0, _ := ret[0].([]*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeers indicates an expected call of GetPeers. +func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter) +} + +// GetPolicy mocks base method. +func (m *MockManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPolicy", ctx, accountID, policyID, userID) + ret0, _ := ret[0].(*types.Policy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPolicy indicates an expected call of GetPolicy. +func (mr *MockManagerMockRecorder) GetPolicy(ctx, accountID, policyID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPolicy", reflect.TypeOf((*MockManager)(nil).GetPolicy), ctx, accountID, policyID, userID) +} + +// GetPostureChecks mocks base method. +func (m *MockManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPostureChecks", ctx, accountID, postureChecksID, userID) + ret0, _ := ret[0].(*posture.Checks) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPostureChecks indicates an expected call of GetPostureChecks. +func (mr *MockManagerMockRecorder) GetPostureChecks(ctx, accountID, postureChecksID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPostureChecks", reflect.TypeOf((*MockManager)(nil).GetPostureChecks), ctx, accountID, postureChecksID, userID) +} + +// GetRoute mocks base method. +func (m *MockManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoute", ctx, accountID, routeID, userID) + ret0, _ := ret[0].(*route.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoute indicates an expected call of GetRoute. +func (mr *MockManagerMockRecorder) GetRoute(ctx, accountID, routeID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoute", reflect.TypeOf((*MockManager)(nil).GetRoute), ctx, accountID, routeID, userID) +} + +// GetSetupKey mocks base method. +func (m *MockManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSetupKey", ctx, accountID, userID, keyID) + ret0, _ := ret[0].(*types.SetupKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSetupKey indicates an expected call of GetSetupKey. +func (mr *MockManagerMockRecorder) GetSetupKey(ctx, accountID, userID, keyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetupKey", reflect.TypeOf((*MockManager)(nil).GetSetupKey), ctx, accountID, userID, keyID) +} + +// GetStore mocks base method. +func (m *MockManager) GetStore() store.Store { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStore") + ret0, _ := ret[0].(store.Store) + return ret0 +} + +// GetStore indicates an expected call of GetStore. +func (mr *MockManagerMockRecorder) GetStore() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockManager)(nil).GetStore)) +} + +// GetUserByID mocks base method. +func (m *MockManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserByID", ctx, id) + ret0, _ := ret[0].(*types.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserByID indicates an expected call of GetUserByID. +func (mr *MockManagerMockRecorder) GetUserByID(ctx, id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserByID", reflect.TypeOf((*MockManager)(nil).GetUserByID), ctx, id) +} + +// GetUserFromUserAuth mocks base method. +func (m *MockManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserFromUserAuth", ctx, userAuth) + ret0, _ := ret[0].(*types.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserFromUserAuth indicates an expected call of GetUserFromUserAuth. +func (mr *MockManagerMockRecorder) GetUserFromUserAuth(ctx, userAuth interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserFromUserAuth", reflect.TypeOf((*MockManager)(nil).GetUserFromUserAuth), ctx, userAuth) +} + +// GetUserIDByPeerKey mocks base method. +func (m *MockManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserIDByPeerKey", ctx, peerKey) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserIDByPeerKey indicates an expected call of GetUserIDByPeerKey. +func (mr *MockManagerMockRecorder) GetUserIDByPeerKey(ctx, peerKey interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserIDByPeerKey", reflect.TypeOf((*MockManager)(nil).GetUserIDByPeerKey), ctx, peerKey) +} + +// GetUserInviteInfo mocks base method. +func (m *MockManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserInviteInfo", ctx, token) + ret0, _ := ret[0].(*types.UserInviteInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserInviteInfo indicates an expected call of GetUserInviteInfo. +func (mr *MockManagerMockRecorder) GetUserInviteInfo(ctx, token interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserInviteInfo", reflect.TypeOf((*MockManager)(nil).GetUserInviteInfo), ctx, token) +} + +// GetUsersFromAccount mocks base method. +func (m *MockManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUsersFromAccount", ctx, accountID, userID) + ret0, _ := ret[0].(map[string]*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUsersFromAccount indicates an expected call of GetUsersFromAccount. +func (mr *MockManagerMockRecorder) GetUsersFromAccount(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsersFromAccount", reflect.TypeOf((*MockManager)(nil).GetUsersFromAccount), ctx, accountID, userID) +} + +// GetValidatedPeers mocks base method. +func (m *MockManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValidatedPeers", ctx, accountID) + ret0, _ := ret[0].(map[string]struct{}) + ret1, _ := ret[1].(map[string]string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetValidatedPeers indicates an expected call of GetValidatedPeers. +func (mr *MockManagerMockRecorder) GetValidatedPeers(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeers", reflect.TypeOf((*MockManager)(nil).GetValidatedPeers), ctx, accountID) +} + +// GroupAddPeer mocks base method. +func (m *MockManager) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupAddPeer", ctx, accountId, groupID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// GroupAddPeer indicates an expected call of GroupAddPeer. +func (mr *MockManagerMockRecorder) GroupAddPeer(ctx, accountId, groupID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupAddPeer", reflect.TypeOf((*MockManager)(nil).GroupAddPeer), ctx, accountId, groupID, peerID) +} + +// GroupDeletePeer mocks base method. +func (m *MockManager) GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupDeletePeer", ctx, accountId, groupID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// GroupDeletePeer indicates an expected call of GroupDeletePeer. +func (mr *MockManagerMockRecorder) GroupDeletePeer(ctx, accountId, groupID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupDeletePeer", reflect.TypeOf((*MockManager)(nil).GroupDeletePeer), ctx, accountId, groupID, peerID) +} + +// GroupValidation mocks base method. +func (m *MockManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupValidation", ctx, accountId, groups) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GroupValidation indicates an expected call of GroupValidation. +func (mr *MockManagerMockRecorder) GroupValidation(ctx, accountId, groups interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupValidation", reflect.TypeOf((*MockManager)(nil).GroupValidation), ctx, accountId, groups) +} + +// InviteUser mocks base method. +func (m *MockManager) InviteUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InviteUser", ctx, accountID, initiatorUserID, targetUserID) + ret0, _ := ret[0].(error) + return ret0 +} + +// InviteUser indicates an expected call of InviteUser. +func (mr *MockManagerMockRecorder) InviteUser(ctx, accountID, initiatorUserID, targetUserID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteUser", reflect.TypeOf((*MockManager)(nil).InviteUser), ctx, accountID, initiatorUserID, targetUserID) +} + +// ListNameServerGroups mocks base method. +func (m *MockManager) ListNameServerGroups(ctx context.Context, accountID, userID string) ([]*dns.NameServerGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListNameServerGroups", ctx, accountID, userID) + ret0, _ := ret[0].([]*dns.NameServerGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListNameServerGroups indicates an expected call of ListNameServerGroups. +func (mr *MockManagerMockRecorder) ListNameServerGroups(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNameServerGroups", reflect.TypeOf((*MockManager)(nil).ListNameServerGroups), ctx, accountID, userID) +} + +// ListPolicies mocks base method. +func (m *MockManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListPolicies", ctx, accountID, userID) + ret0, _ := ret[0].([]*types.Policy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListPolicies indicates an expected call of ListPolicies. +func (mr *MockManagerMockRecorder) ListPolicies(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPolicies", reflect.TypeOf((*MockManager)(nil).ListPolicies), ctx, accountID, userID) +} + +// ListPostureChecks mocks base method. +func (m *MockManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListPostureChecks", ctx, accountID, userID) + ret0, _ := ret[0].([]*posture.Checks) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListPostureChecks indicates an expected call of ListPostureChecks. +func (mr *MockManagerMockRecorder) ListPostureChecks(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPostureChecks", reflect.TypeOf((*MockManager)(nil).ListPostureChecks), ctx, accountID, userID) +} + +// ListRoutes mocks base method. +func (m *MockManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRoutes", ctx, accountID, userID) + ret0, _ := ret[0].([]*route.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRoutes indicates an expected call of ListRoutes. +func (mr *MockManagerMockRecorder) ListRoutes(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRoutes", reflect.TypeOf((*MockManager)(nil).ListRoutes), ctx, accountID, userID) +} + +// ListSetupKeys mocks base method. +func (m *MockManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListSetupKeys", ctx, accountID, userID) + ret0, _ := ret[0].([]*types.SetupKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSetupKeys indicates an expected call of ListSetupKeys. +func (mr *MockManagerMockRecorder) ListSetupKeys(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSetupKeys", reflect.TypeOf((*MockManager)(nil).ListSetupKeys), ctx, accountID, userID) +} + +// ListUserInvites mocks base method. +func (m *MockManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListUserInvites", ctx, accountID, initiatorUserID) + ret0, _ := ret[0].([]*types.UserInvite) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListUserInvites indicates an expected call of ListUserInvites. +func (mr *MockManagerMockRecorder) ListUserInvites(ctx, accountID, initiatorUserID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserInvites", reflect.TypeOf((*MockManager)(nil).ListUserInvites), ctx, accountID, initiatorUserID) +} + +// ListUsers mocks base method. +func (m *MockManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListUsers", ctx, accountID) + ret0, _ := ret[0].([]*types.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListUsers indicates an expected call of ListUsers. +func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUsers", reflect.TypeOf((*MockManager)(nil).ListUsers), ctx, accountID) +} + +// LoginPeer mocks base method. +func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoginPeer", ctx, login) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// LoginPeer indicates an expected call of LoginPeer. +func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login) +} + +// MarkPeerConnected mocks base method. +func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkPeerConnected indicates an expected call of MarkPeerConnected. +func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime) +} + +// OnPeerDisconnected mocks base method. +func (m *MockManager) OnPeerDisconnected(ctx context.Context, accountID, peerPubKey string, streamStartTime time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerPubKey, streamStartTime) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerDisconnected indicates an expected call of OnPeerDisconnected. +func (mr *MockManagerMockRecorder) OnPeerDisconnected(ctx, accountID, peerPubKey, streamStartTime interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockManager)(nil).OnPeerDisconnected), ctx, accountID, peerPubKey, streamStartTime) +} + +// RegenerateUserInvite mocks base method. +func (m *MockManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegenerateUserInvite", ctx, accountID, initiatorUserID, inviteID, expiresIn) + ret0, _ := ret[0].(*types.UserInvite) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RegenerateUserInvite indicates an expected call of RegenerateUserInvite. +func (mr *MockManagerMockRecorder) RegenerateUserInvite(ctx, accountID, initiatorUserID, inviteID, expiresIn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegenerateUserInvite", reflect.TypeOf((*MockManager)(nil).RegenerateUserInvite), ctx, accountID, initiatorUserID, inviteID, expiresIn) +} + +// RejectUser mocks base method. +func (m *MockManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RejectUser", ctx, accountID, initiatorUserID, targetUserID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RejectUser indicates an expected call of RejectUser. +func (mr *MockManagerMockRecorder) RejectUser(ctx, accountID, initiatorUserID, targetUserID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RejectUser", reflect.TypeOf((*MockManager)(nil).RejectUser), ctx, accountID, initiatorUserID, targetUserID) +} + +// SaveDNSSettings mocks base method. +func (m *MockManager) SaveDNSSettings(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveDNSSettings", ctx, accountID, userID, dnsSettingsToSave) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveDNSSettings indicates an expected call of SaveDNSSettings. +func (mr *MockManagerMockRecorder) SaveDNSSettings(ctx, accountID, userID, dnsSettingsToSave interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveDNSSettings", reflect.TypeOf((*MockManager)(nil).SaveDNSSettings), ctx, accountID, userID, dnsSettingsToSave) +} + +// SaveNameServerGroup mocks base method. +func (m *MockManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *dns.NameServerGroup) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveNameServerGroup", ctx, accountID, userID, nsGroupToSave) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveNameServerGroup indicates an expected call of SaveNameServerGroup. +func (mr *MockManagerMockRecorder) SaveNameServerGroup(ctx, accountID, userID, nsGroupToSave interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNameServerGroup", reflect.TypeOf((*MockManager)(nil).SaveNameServerGroup), ctx, accountID, userID, nsGroupToSave) +} + +// SaveOrAddUser mocks base method. +func (m *MockManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOrAddUser", ctx, accountID, initiatorUserID, update, addIfNotExists) + ret0, _ := ret[0].(*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveOrAddUser indicates an expected call of SaveOrAddUser. +func (mr *MockManagerMockRecorder) SaveOrAddUser(ctx, accountID, initiatorUserID, update, addIfNotExists interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOrAddUser", reflect.TypeOf((*MockManager)(nil).SaveOrAddUser), ctx, accountID, initiatorUserID, update, addIfNotExists) +} + +// SaveOrAddUsers mocks base method. +func (m *MockManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOrAddUsers", ctx, accountID, initiatorUserID, updates, addIfNotExists) + ret0, _ := ret[0].([]*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveOrAddUsers indicates an expected call of SaveOrAddUsers. +func (mr *MockManagerMockRecorder) SaveOrAddUsers(ctx, accountID, initiatorUserID, updates, addIfNotExists interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOrAddUsers", reflect.TypeOf((*MockManager)(nil).SaveOrAddUsers), ctx, accountID, initiatorUserID, updates, addIfNotExists) +} + +// SavePolicy mocks base method. +func (m *MockManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SavePolicy", ctx, accountID, userID, policy, create) + ret0, _ := ret[0].(*types.Policy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SavePolicy indicates an expected call of SavePolicy. +func (mr *MockManagerMockRecorder) SavePolicy(ctx, accountID, userID, policy, create interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePolicy", reflect.TypeOf((*MockManager)(nil).SavePolicy), ctx, accountID, userID, policy, create) +} + +// SavePostureChecks mocks base method. +func (m *MockManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SavePostureChecks", ctx, accountID, userID, postureChecks, create) + ret0, _ := ret[0].(*posture.Checks) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SavePostureChecks indicates an expected call of SavePostureChecks. +func (mr *MockManagerMockRecorder) SavePostureChecks(ctx, accountID, userID, postureChecks, create interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockManager)(nil).SavePostureChecks), ctx, accountID, userID, postureChecks, create) +} + +// SaveRoute mocks base method. +func (m *MockManager) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveRoute", ctx, accountID, userID, route) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveRoute indicates an expected call of SaveRoute. +func (mr *MockManagerMockRecorder) SaveRoute(ctx, accountID, userID, route interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveRoute", reflect.TypeOf((*MockManager)(nil).SaveRoute), ctx, accountID, userID, route) +} + +// SaveSetupKey mocks base method. +func (m *MockManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveSetupKey", ctx, accountID, key, userID) + ret0, _ := ret[0].(*types.SetupKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveSetupKey indicates an expected call of SaveSetupKey. +func (mr *MockManagerMockRecorder) SaveSetupKey(ctx, accountID, key, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveSetupKey", reflect.TypeOf((*MockManager)(nil).SaveSetupKey), ctx, accountID, key, userID) +} + +// SaveUser mocks base method. +func (m *MockManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveUser", ctx, accountID, initiatorUserID, update) + ret0, _ := ret[0].(*types.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveUser indicates an expected call of SaveUser. +func (mr *MockManagerMockRecorder) SaveUser(ctx, accountID, initiatorUserID, update interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveUser", reflect.TypeOf((*MockManager)(nil).SaveUser), ctx, accountID, initiatorUserID, update) +} + +// SetServiceManager mocks base method. +func (m *MockManager) SetServiceManager(serviceManager service.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetServiceManager", serviceManager) +} + +// SetServiceManager indicates an expected call of SetServiceManager. +func (mr *MockManagerMockRecorder) SetServiceManager(serviceManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServiceManager", reflect.TypeOf((*MockManager)(nil).SetServiceManager), serviceManager) +} + +// StoreEvent mocks base method. +func (m *MockManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StoreEvent", ctx, initiatorID, targetID, accountID, activityID, meta) +} + +// StoreEvent indicates an expected call of StoreEvent. +func (mr *MockManagerMockRecorder) StoreEvent(ctx, initiatorID, targetID, accountID, activityID, meta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoreEvent", reflect.TypeOf((*MockManager)(nil).StoreEvent), ctx, initiatorID, targetID, accountID, activityID, meta) +} + +// SyncAndMarkPeer mocks base method. +func (m *MockManager) SyncAndMarkPeer(ctx context.Context, accountID, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SyncAndMarkPeer", ctx, accountID, peerPubKey, meta, realIP, syncTime) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(int64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 +} + +// SyncAndMarkPeer indicates an expected call of SyncAndMarkPeer. +func (mr *MockManagerMockRecorder) SyncAndMarkPeer(ctx, accountID, peerPubKey, meta, realIP, syncTime interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncAndMarkPeer", reflect.TypeOf((*MockManager)(nil).SyncAndMarkPeer), ctx, accountID, peerPubKey, meta, realIP, syncTime) +} + +// SyncPeer mocks base method. +func (m *MockManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SyncPeer", ctx, sync, accountID) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(int64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 +} + +// SyncPeer indicates an expected call of SyncPeer. +func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeer", reflect.TypeOf((*MockManager)(nil).SyncPeer), ctx, sync, accountID) +} + +// SyncPeerMeta mocks base method. +func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta) + ret0, _ := ret[0].(error) + return ret0 +} + +// SyncPeerMeta indicates an expected call of SyncPeerMeta. +func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta) +} + +// SyncUserJWTGroups mocks base method. +func (m *MockManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SyncUserJWTGroups", ctx, userAuth) + ret0, _ := ret[0].(error) + return ret0 +} + +// SyncUserJWTGroups indicates an expected call of SyncUserJWTGroups. +func (mr *MockManagerMockRecorder) SyncUserJWTGroups(ctx, userAuth interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncUserJWTGroups", reflect.TypeOf((*MockManager)(nil).SyncUserJWTGroups), ctx, userAuth) +} + +// UpdateAccountOnboarding mocks base method. +func (m *MockManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountOnboarding", ctx, accountID, userID, newOnboarding) + ret0, _ := ret[0].(*types.AccountOnboarding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateAccountOnboarding indicates an expected call of UpdateAccountOnboarding. +func (mr *MockManagerMockRecorder) UpdateAccountOnboarding(ctx, accountID, userID, newOnboarding interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountOnboarding", reflect.TypeOf((*MockManager)(nil).UpdateAccountOnboarding), ctx, accountID, userID, newOnboarding) +} + +// UpdateAccountPeers mocks base method. +func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) +} + +// UpdateAccountPeers indicates an expected call of UpdateAccountPeers. +func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID) +} + +// UpdateAccountSettings mocks base method. +func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountSettings", ctx, accountID, userID, newSettings) + ret0, _ := ret[0].(*types.Settings) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateAccountSettings indicates an expected call of UpdateAccountSettings. +func (mr *MockManagerMockRecorder) UpdateAccountSettings(ctx, accountID, userID, newSettings interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountSettings", reflect.TypeOf((*MockManager)(nil).UpdateAccountSettings), ctx, accountID, userID, newSettings) +} + +// UpdateGroup mocks base method. +func (m *MockManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateGroup", ctx, accountID, userID, group) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateGroup indicates an expected call of UpdateGroup. +func (mr *MockManagerMockRecorder) UpdateGroup(ctx, accountID, userID, group interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroup", reflect.TypeOf((*MockManager)(nil).UpdateGroup), ctx, accountID, userID, group) +} + +// UpdateGroups mocks base method. +func (m *MockManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateGroups", ctx, accountID, userID, newGroups) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateGroups indicates an expected call of UpdateGroups. +func (mr *MockManagerMockRecorder) UpdateGroups(ctx, accountID, userID, newGroups interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockManager)(nil).UpdateGroups), ctx, accountID, userID, newGroups) +} + +// UpdateIdentityProvider mocks base method. +func (m *MockManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateIdentityProvider", ctx, accountID, idpID, userID, idp) + ret0, _ := ret[0].(*types.IdentityProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateIdentityProvider indicates an expected call of UpdateIdentityProvider. +func (mr *MockManagerMockRecorder) UpdateIdentityProvider(ctx, accountID, idpID, userID, idp interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateIdentityProvider", reflect.TypeOf((*MockManager)(nil).UpdateIdentityProvider), ctx, accountID, idpID, userID, idp) +} + +// UpdateIntegratedValidator mocks base method. +func (m *MockManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateIntegratedValidator", ctx, accountID, userID, validator, groups) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateIntegratedValidator indicates an expected call of UpdateIntegratedValidator. +func (mr *MockManagerMockRecorder) UpdateIntegratedValidator(ctx, accountID, userID, validator, groups interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateIntegratedValidator", reflect.TypeOf((*MockManager)(nil).UpdateIntegratedValidator), ctx, accountID, userID, validator, groups) +} + +// UpdatePeer mocks base method. +func (m *MockManager) UpdatePeer(ctx context.Context, accountID, userID string, p *peer.Peer) (*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePeer", ctx, accountID, userID, p) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdatePeer indicates an expected call of UpdatePeer. +func (mr *MockManagerMockRecorder) UpdatePeer(ctx, accountID, userID, p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeer", reflect.TypeOf((*MockManager)(nil).UpdatePeer), ctx, accountID, userID, p) +} + +// UpdatePeerIP mocks base method. +func (m *MockManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePeerIP", ctx, accountID, userID, peerID, newIP) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdatePeerIP indicates an expected call of UpdatePeerIP. +func (mr *MockManagerMockRecorder) UpdatePeerIP(ctx, accountID, userID, peerID, newIP interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeerIP", reflect.TypeOf((*MockManager)(nil).UpdatePeerIP), ctx, accountID, userID, peerID, newIP) +} + +// UpdateToPrimaryAccount mocks base method. +func (m *MockManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateToPrimaryAccount", ctx, accountId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateToPrimaryAccount indicates an expected call of UpdateToPrimaryAccount. +func (mr *MockManagerMockRecorder) UpdateToPrimaryAccount(ctx, accountId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToPrimaryAccount", reflect.TypeOf((*MockManager)(nil).UpdateToPrimaryAccount), ctx, accountId) +} + +// UpdateUserPassword mocks base method. +func (m *MockManager) UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID, oldPassword, newPassword string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUserPassword", ctx, accountID, currentUserID, targetUserID, oldPassword, newPassword) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateUserPassword indicates an expected call of UpdateUserPassword. +func (mr *MockManagerMockRecorder) UpdateUserPassword(ctx, accountID, currentUserID, targetUserID, oldPassword, newPassword interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserPassword", reflect.TypeOf((*MockManager)(nil).UpdateUserPassword), ctx, accountID, currentUserID, targetUserID, oldPassword, newPassword) +} diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index fa6c45856..e1672c2d0 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -86,7 +86,14 @@ func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, acco result := &AccountResult{Account: account, Err: err} for _, req := range requests { - req.ResultChan <- result + if account != nil { + // Shallow copy the account so each goroutine gets its own struct value. + // This prevents data races when callers mutate fields like Policies. + accountCopy := *account + req.ResultChan <- &AccountResult{Account: &accountCopy, Err: err} + } else { + req.ResultChan <- result + } close(req.ResultChan) } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 1cc0c9571..4453d064e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -19,18 +19,25 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/shared/management/status" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" - reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -1802,12 +1809,19 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, - Services: []*reverseproxy.Service{ + Services: []*service.Service{ { ID: "service1", Name: "test-service", AccountID: "account1", - Targets: []*reverseproxy.Target{}, + Targets: []*service.Target{}, + }, + }, + Domains: []*domain.Domain{ + { + ID: "domain1", + Domain: "test.com", + AccountID: "account1", }, }, NetworkMapCache: &types.NetworkMapBuilder{}, @@ -3112,17 +3126,33 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU permissionsManager := permissions.NewManager(store) peersManager := peers.NewManager(store, permissionsManager) + proxyManager := proxy.NewMockManager(ctrl) + proxyManager.EXPECT(). + CleanupStale(gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + ctx := context.Background() - updateManager := update_channel.NewPeersUpdateManager(metrics) - requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil)) + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) + if err != nil { + return nil, nil, err + } + + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) + if err != nil { + return nil, nil, err + } + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil)) return manager, updateManager, nil } @@ -3951,3 +3981,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange") } } + +func TestUpdateUserAuthWithSingleMode(t *testing.T) { + t.Run("sets defaults and overrides domain from store", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("account-1", nil) + mockStore.EXPECT(). + GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1"). + Return("real-domain.com", "private", nil) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.NoError(t, err) + assert.Equal(t, "real-domain.com", userAuth.Domain) + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("", nil) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.NoError(t, err) + assert.Equal(t, "fallback.com", userAuth.Domain) + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("", status.Errorf(status.NotFound, "no accounts")) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.NoError(t, err) + assert.Equal(t, "fallback.com", userAuth.Domain) + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("", status.Errorf(status.Internal, "db down")) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.Error(t, err) + assert.Contains(t, err.Error(), "db down") + // Defaults should still be set before error path + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("account-1", nil) + mockStore.EXPECT(). + GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1"). + Return("", "", status.Errorf(status.Internal, "query failed")) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.Error(t, err) + assert.Contains(t, err.Error(), "query failed") + }) +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index e1b7e5300..ddc3e00c3 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -208,6 +208,30 @@ const ( ServiceUpdated Activity = 109 ServiceDeleted Activity = 110 + // PeerServiceExposed indicates that a peer exposed a service via the reverse proxy + PeerServiceExposed Activity = 111 + // PeerServiceUnexposed indicates that a peer-exposed service was removed + PeerServiceUnexposed Activity = 112 + // PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration + PeerServiceExposeExpired Activity = 113 + + // AccountPeerExposeEnabled indicates that a user enabled peer expose for the account + AccountPeerExposeEnabled Activity = 114 + // AccountPeerExposeDisabled indicates that a user disabled peer expose for the account + AccountPeerExposeDisabled Activity = 115 + + // AccountAutoUpdateAlwaysEnabled indicates that a user enabled always auto-update for the account + AccountAutoUpdateAlwaysEnabled Activity = 116 + // AccountAutoUpdateAlwaysDisabled indicates that a user disabled always auto-update for the account + AccountAutoUpdateAlwaysDisabled Activity = 117 + + // DomainAdded indicates that a user added a custom domain + DomainAdded Activity = 118 + // DomainDeleted indicates that a user deleted a custom domain + DomainDeleted Activity = 119 + // DomainValidated indicates that a custom domain was validated + DomainValidated Activity = 120 + AccountDeleted Activity = 99999 ) @@ -320,6 +344,8 @@ var activityMap = map[Activity]Code{ UserCreated: {"User created", "user.create"}, AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"}, + AccountAutoUpdateAlwaysEnabled: {"Account auto-update always enabled", "account.setting.auto.update.always.enable"}, + AccountAutoUpdateAlwaysDisabled: {"Account auto-update always disabled", "account.setting.auto.update.always.disable"}, IdentityProviderCreated: {"Identity provider created", "identityprovider.create"}, IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"}, @@ -345,6 +371,17 @@ var activityMap = map[Activity]Code{ ServiceCreated: {"Service created", "service.create"}, ServiceUpdated: {"Service updated", "service.update"}, ServiceDeleted: {"Service deleted", "service.delete"}, + + PeerServiceExposed: {"Peer exposed service", "service.peer.expose"}, + PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"}, + PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"}, + + AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"}, + AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"}, + + DomainAdded: {"Domain added", "domain.add"}, + DomainDeleted: {"Domain deleted", "domain.delete"}, + DomainValidated: {"Domain validated", "domain.validate"}, } // StringCode returns a string code of the activity diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index db614d0cd..73e8e295c 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) { switch storeEngine { case types.SqliteStoreEngine: - dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB)) + dbFile := eventSinkDB + if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" { + dbFile = envFile + } + connStr := dbFile + if !filepath.IsAbs(dbFile) { + connStr = filepath.Join(dataDir, dbFile) + } + dialector = sqlite.Open(connStr) case types.PostgresStoreEngine: dsn, ok := os.LookupEnv(postgresDsnEnv) if !ok { diff --git a/management/server/activity/store/sql_store_idp_migration.go b/management/server/activity/store/sql_store_idp_migration.go new file mode 100644 index 000000000..1b3a9ecd9 --- /dev/null +++ b/management/server/activity/store/sql_store_idp_migration.go @@ -0,0 +1,61 @@ +package store + +// This file contains migration-only methods on Store. +// They satisfy the migration.MigrationEventStore interface via duck typing. +// Delete this file when migration tooling is no longer needed. + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp/migration" +) + +// CheckSchema verifies that all tables and columns required by the migration exist in the event database. +func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError { + migrator := store.db.Migrator() + var errs []migration.SchemaError + + for _, check := range checks { + if !migrator.HasTable(check.Table) { + errs = append(errs, migration.SchemaError{Table: check.Table}) + continue + } + for _, col := range check.Columns { + if !migrator.HasColumn(check.Table, col) { + errs = append(errs, migration.SchemaError{Table: check.Table, Column: col}) + } + } + } + + return errs +} + +// UpdateUserID updates all references to oldUserID in events and deleted_users tables. +func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error { + return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&activity.Event{}). + Where("initiator_id = ?", oldUserID). + Update("initiator_id", newUserID).Error; err != nil { + return fmt.Errorf("update events.initiator_id: %w", err) + } + + if err := tx.Model(&activity.Event{}). + Where("target_id = ?", oldUserID). + Update("target_id", newUserID).Error; err != nil { + return fmt.Errorf("update events.target_id: %w", err) + } + + // Raw exec: GORM can't update a PK via Model().Update() + if err := tx.Exec( + "UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID, + ).Error; err != nil { + return fmt.Errorf("update deleted_users.id: %w", err) + } + + return nil + }) +} diff --git a/management/server/activity/store/sql_store_idp_migration_test.go b/management/server/activity/store/sql_store_idp_migration_test.go new file mode 100644 index 000000000..98b6e1327 --- /dev/null +++ b/management/server/activity/store/sql_store_idp_migration_test.go @@ -0,0 +1,161 @@ +package store + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/util/crypt" +) + +func TestUpdateUserID(t *testing.T) { + ctx := context.Background() + + newStore := func(t *testing.T) *Store { + t.Helper() + key, _ := crypt.GenerateKey() + s, err := NewSqlStore(ctx, t.TempDir(), key) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { s.Close(ctx) }) //nolint + return s + } + + t.Run("updates initiator_id in events", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "old-user", + TargetID: "some-peer", + AccountID: accountID, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "old-user", "new-user") + assert.NoError(t, err) + + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "new-user", result[0].InitiatorID) + }) + + t.Run("updates target_id in events", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "some-admin", + TargetID: "old-user", + AccountID: accountID, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "old-user", "new-user") + assert.NoError(t, err) + + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "new-user", result[0].TargetID) + }) + + t.Run("updates deleted_users id", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + // Save an event with email/name meta to create a deleted_users row for "old-user" + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "admin", + TargetID: "old-user", + AccountID: accountID, + Meta: map[string]any{ + "email": "user@example.com", + "name": "Test User", + }, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "old-user", "new-user") + assert.NoError(t, err) + + // Save another event referencing new-user with email/name meta. + // This should upsert (not conflict) because the PK was already migrated. + _, err = store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "admin", + TargetID: "new-user", + AccountID: accountID, + Meta: map[string]any{ + "email": "user@example.com", + "name": "Test User", + }, + }) + assert.NoError(t, err) + + // The deleted user info should be retrievable via Get (joined on target_id) + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 2) + for _, ev := range result { + assert.Equal(t, "new-user", ev.TargetID) + } + }) + + t.Run("no-op when old user ID does not exist", func(t *testing.T) { + store := newStore(t) + + err := store.UpdateUserID(ctx, "nonexistent-user", "new-user") + assert.NoError(t, err) + }) + + t.Run("only updates matching user leaves others unchanged", func(t *testing.T) { + store := newStore(t) + accountID := "account_1" + + _, err := store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "user-a", + TargetID: "peer-1", + AccountID: accountID, + }) + assert.NoError(t, err) + + _, err = store.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "user-b", + TargetID: "peer-2", + AccountID: accountID, + }) + assert.NoError(t, err) + + err = store.UpdateUserID(ctx, "user-a", "user-a-new") + assert.NoError(t, err) + + result, err := store.Get(ctx, accountID, 0, 10, false) + assert.NoError(t, err) + assert.Len(t, result, 2) + + for _, ev := range result { + if ev.TargetID == "peer-1" { + assert.Equal(t, "user-a-new", ev.InitiatorID) + } else { + assert.Equal(t, "user-b", ev.InitiatorID) + } + } + }) +} diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index 76cc750b6..27346a604 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -33,15 +33,20 @@ type manager struct { extractor *nbjwt.ClaimsExtractor } -func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager { - // @note if invalid/missing parameters are sent the validator will instantiate - // but it will fail when validating and parsing the token - jwtValidator := nbjwt.NewValidator( - issuer, - allAudiences, - keysLocation, - idpRefreshKeys, - ) +func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager { + var jwtValidator *nbjwt.Validator + if keyFetcher != nil { + jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher) + } else { + // @note if invalid/missing parameters are sent the validator will instantiate + // but it will fail when validating and parsing the token + jwtValidator = nbjwt.NewValidator( + issuer, + allAudiences, + keysLocation, + idpRefreshKeys, + ) + } claimsExtractor := nbjwt.NewClaimsExtractor( nbjwt.WithAudience(audience), diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index b9f091b1e..469737f47 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) user, pat, _, _, err := manager.GetPATInfo(context.Background(), token) if err != nil { @@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) err = manager.MarkPATUsed(context.Background(), "tokenId") if err != nil { @@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { // these tests only assert groups are parsed from token as per account settings token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}) - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) t.Run("JWT groups disabled", func(t *testing.T) { userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) @@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { keyId := "test-key" // note, we can use a nil store because ValidateAndParseToken does not use it in it's flow - manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false) + manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil) customClaim := func(name string) string { return fmt.Sprintf("%s/%s", audience, name) diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 54b0242de..2ca8e8603 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -17,12 +17,24 @@ import ( // RedisStoreEnvVar is the environment variable that determines if a redis store should be used. // The value should follow redis URL format. https://github.com/redis/redis-specifications/blob/master/uri/redis.txt -const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" +const RedisStoreEnvVar = "NB_CACHE_REDIS_ADDRESS" + +// legacyIdPCacheRedisEnvVar is the previous environment variable used for IDP cache. +const legacyIdPCacheRedisEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" + +const ( + // DefaultStoreMaxTimeout is the default max timeout for the shared cache store. + DefaultStoreMaxTimeout = 7 * 24 * time.Hour + // DefaultStoreCleanupInterval is the default cleanup interval for the shared cache store. + DefaultStoreCleanupInterval = 30 * time.Minute + // DefaultStoreMaxConn is the default max connections for the shared cache store. + DefaultStoreMaxConn = 1000 +) // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { - redisAddr := os.Getenv(RedisStoreEnvVar) + redisAddr := GetAddrFromEnv() if redisAddr != "" { return getRedisStore(ctx, redisAddr, maxConn) } @@ -30,6 +42,15 @@ func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, ma return gocache_store.NewGoCache(goc), nil } +// GetAddrFromEnv returns the redis address from the environment variable RedisStoreEnvVar or its legacy counterpart. +func GetAddrFromEnv() string { + addr := os.Getenv(RedisStoreEnvVar) + if addr == "" { + addr = os.Getenv(legacyIdPCacheRedisEnvVar) + } + return addr +} + func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index bd0755d0d..0e37a3b22 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/permissions" @@ -225,11 +226,17 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index c0179a1c4..0af3ce2f6 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -44,6 +44,12 @@ type Record struct { GeonameID uint `maxminddb:"geoname_id"` ISOCode string `maxminddb:"iso_code"` } `maxminddb:"country"` + Subdivisions []struct { + ISOCode string `maxminddb:"iso_code"` + Names struct { + En string `maxminddb:"en"` + } `maxminddb:"names"` + } `maxminddb:"subdivisions"` } type City struct { @@ -124,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() + if gl.db == nil { + return nil, fmt.Errorf("geolocation database is not available") + } + var record Record err := gl.db.Lookup(ip, &record) if err != nil { @@ -167,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er func (gl *geolocationImpl) Stop() error { close(gl.stopCh) - if gl.db != nil { - if err := gl.db.Close(); err != nil { + + gl.mux.Lock() + db := gl.db + gl.db = nil + gl.mux.Unlock() + + if db != nil { + if err := db.Close(); err != nil { return err } } diff --git a/management/server/group.go b/management/server/group.go index 9fc8db120..7b5b9b86c 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err + } return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) } @@ -425,6 +428,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us var groupIDsToDelete []string var deletedGroups []*types.Group + extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) @@ -433,7 +441,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us continue } - if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil { allErrors = errors.Join(allErrors, err) continue } @@ -621,7 +629,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st return nil } -func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == types.GroupIssuedIntegration { executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID) @@ -641,6 +649,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty return &GroupLinkError{"network resource", group.Resources[0].ID} } + if slices.Contains(flowGroups, group.ID) { + return &GroupLinkError{"settings", "traffic event logging"} + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } diff --git a/management/server/group_test.go b/management/server/group_test.go index dba917dbb..fa818e532 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -26,6 +27,7 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" peer2 "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" @@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } +func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) { + am, _, err := createManager(t) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + settingsMock := settings.NewMockManager(ctrl) + settingsMock.EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil). + AnyTimes() + settingsMock.EXPECT(). + UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(false, nil). + AnyTimes() + am.settingsManager = settingsMock + + _, account, err := initTestGroupAccount(am) + require.NoError(t, err) + + grp := &types.Group{ + ID: "grp-for-flow", + AccountID: account.Id, + Name: "Group for flow", + Issued: types.GroupIssuedAPI, + Peers: make([]string, 0), + } + require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp)) + + err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow") + require.Error(t, err) + + var gErr *GroupLinkError + require.ErrorAs(t, err, &gErr) + assert.Equal(t, "settings", gErr.Resource) + assert.Equal(t, "traffic event logging", gErr.Name) + + group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID) + require.NoError(t, err) + assert.NotNil(t, group) + + regularGrp := &types.Group{ + ID: "grp-regular", + AccountID: account.Id, + Name: "Regular group", + Issued: types.GroupIssuedAPI, + Peers: make([]string, 0), + } + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp) + require.NoError(t, err) + + err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"}) + require.Error(t, err) + + group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID) + require.NoError(t, err) + assert.NotNil(t, group) + + _, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID) + assert.Error(t, err) +} + func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { accountID := "testingAcc" domain := "example.com" @@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) { permissionsManager := permissions.NewManager(manager.Store) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) - resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager) + resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager) networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 9d2384cae..ad36b9d46 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -17,9 +17,9 @@ import ( "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" - reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" idpmanager "github.com/netbirdio/netbird/management/server/idp" @@ -73,7 +73,7 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -173,10 +173,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) - if reverseProxyManager != nil && reverseProxyDomainManager != nil { - reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + if serviceManager != nil && reverseProxyDomainManager != nil { + reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } - // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 122c061ce..cc5567e3d 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -168,6 +168,10 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { } func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) { + if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 { + return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group") + } + returnSettings := &types.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), @@ -175,6 +179,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), + + PeerExposeEnabled: req.Settings.PeerExposeEnabled, + PeerExposeGroups: req.Settings.PeerExposeGroups, } if req.Settings.Extra != nil { @@ -218,6 +225,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS return nil, fmt.Errorf("invalid AutoUpdateVersion") } } + if req.Settings.AutoUpdateAlways != nil { + returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways + } return returnSettings, nil } @@ -336,9 +346,12 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, + PeerExposeEnabled: settings.PeerExposeEnabled, + PeerExposeGroups: settings.PeerExposeGroups, LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, + AutoUpdateAlways: &settings.AutoUpdateAlways, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 6cbd5908d..739dfe2f6 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -121,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -146,6 +147,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -171,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -196,6 +199,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -221,6 +225,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -246,6 +251,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 56ccc9d0b..f8d161a87 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { groupName := r.URL.Query().Get("name") if groupName != "" { // Get single group by name - group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 458a15c11..c7b4cbcdd 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return groups, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index c311a29fe..ce9efb78d 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -105,6 +105,12 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = networkID router.AccountID = accountID router.Enabled = true + + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) @@ -157,6 +163,11 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { router.ID = mux.Vars(r)["routerId"] router.AccountID = accountID + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 6a1b144f6..c99acab63 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -18,9 +18,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -190,7 +192,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcServer := newFakeOIDCServer() - tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) usersManager := users.NewManager(testStore) @@ -205,12 +211,14 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, + pkceStore, oidcConfig, nil, usersManager, + nil, ) - proxyService.SetProxyManager(&testServiceManager{store: testStore}) + proxyService.SetServiceManager(&testServiceManager{store: testStore}) handler := NewAuthCallbackHandler(proxyService, nil) @@ -239,12 +247,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store pubKey := base64.StdEncoding.EncodeToString(pub) privKey := base64.StdEncoding.EncodeToString(priv) - testProxy := &reverseproxy.Service{ + testProxy := &service.Service{ ID: "testProxyId", AccountID: "testAccountId", Name: "Test Proxy", Domain: "test-proxy.example.com", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "localhost", Port: 8080, @@ -254,8 +262,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store Enabled: true, }}, Enabled: true, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"allowedGroupId"}, }, @@ -265,12 +273,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store } require.NoError(t, testStore.CreateService(ctx, testProxy)) - restrictedProxy := &reverseproxy.Service{ + restrictedProxy := &service.Service{ ID: "restrictedProxyId", AccountID: "testAccountId", Name: "Restricted Proxy", Domain: "restricted-proxy.example.com", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "localhost", Port: 8080, @@ -280,8 +288,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store Enabled: true, }}, Enabled: true, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"restrictedGroupId"}, }, @@ -291,12 +299,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store } require.NoError(t, testStore.CreateService(ctx, restrictedProxy)) - noAuthProxy := &reverseproxy.Service{ + noAuthProxy := &service.Service{ ID: "noAuthProxyId", AccountID: "testAccountId", Name: "No Auth Proxy", Domain: "no-auth-proxy.example.com", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "localhost", Port: 8080, @@ -306,8 +314,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store Enabled: true, }}, Enabled: true, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: false, }, }, @@ -357,19 +365,23 @@ type testServiceManager struct { store store.Store } -func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { +func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + +func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) { return nil, nil } -func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { +func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) { return nil, nil } -func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } -func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } @@ -381,7 +393,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri return nil } -func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { +func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error { return nil } @@ -393,15 +405,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error return nil } -func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return m.store.GetServices(ctx, store.LockingStrengthNone) } -func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { +func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) } -func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } @@ -409,6 +421,24 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri return "", nil } +func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + return nil, nil +} + +func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testServiceManager) StartExposeReaper(_ context.Context) {} + +func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { + return nil, nil +} + func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { t.Helper() diff --git a/management/server/http/middleware/bypass/bypass.go b/management/server/http/middleware/bypass/bypass.go index 9447704cb..ddece7152 100644 --- a/management/server/http/middleware/bypass/bypass.go +++ b/management/server/http/middleware/bypass/bypass.go @@ -51,19 +51,28 @@ func GetList() []string { // This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication. func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool { byPassMutex.RLock() - defer byPassMutex.RUnlock() - + var matched bool for bypassPath := range bypassPaths { - matched, err := path.Match(bypassPath, requestPath) + m, err := path.Match(bypassPath, requestPath) if err != nil { - log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) + list := make([]string, 0, len(bypassPaths)) + for k := range bypassPaths { + list = append(list, k) + } + log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %v: %v", bypassPath, requestPath, list, err) continue } - if matched { - h.ServeHTTP(w, r) - return true + if m { + matched = true + break } } + byPassMutex.RUnlock() + + if matched { + h.ServeHTTP(w, r) + return true + } return false } diff --git a/management/server/http/testing/integration/accounts_handler_integration_test.go b/management/server/http/testing/integration/accounts_handler_integration_test.go new file mode 100644 index 000000000..511730ee5 --- /dev/null +++ b/management/server/http/testing/integration/accounts_handler_integration_test.go @@ -0,0 +1,238 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Accounts_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all accounts", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/accounts", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Account{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + account := got[0] + assert.Equal(t, "test.com", account.Domain) + assert.Equal(t, "private", account.DomainCategory) + assert.Equal(t, true, account.Settings.PeerLoginExpirationEnabled) + assert.Equal(t, 86400, account.Settings.PeerLoginExpiration) + assert.Equal(t, false, account.Settings.RegularUsersViewBlocked) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Accounts_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + trueVal := true + falseVal := false + + tt := []struct { + name string + expectedStatus int + requestBody *api.AccountRequest + verifyResponse func(t *testing.T, account *api.Account) + verifyDB func(t *testing.T, account *types.Account) + }{ + { + name: "Disable peer login expiration", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: false, + PeerLoginExpiration: 86400, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.Equal(t, false, account.Settings.PeerLoginExpirationEnabled) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, false, dbAccount.Settings.PeerLoginExpirationEnabled) + }, + }, + { + name: "Update peer login expiration to 48h", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 172800, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.Equal(t, 172800, account.Settings.PeerLoginExpiration) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, 172800*time.Second, dbAccount.Settings.PeerLoginExpiration) + }, + }, + { + name: "Enable regular users view blocked", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 86400, + RegularUsersViewBlocked: true, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.Equal(t, true, account.Settings.RegularUsersViewBlocked) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, true, dbAccount.Settings.RegularUsersViewBlocked) + }, + }, + { + name: "Enable groups propagation", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 86400, + GroupsPropagationEnabled: &trueVal, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.NotNil(t, account.Settings.GroupsPropagationEnabled) + assert.Equal(t, true, *account.Settings.GroupsPropagationEnabled) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, true, dbAccount.Settings.GroupsPropagationEnabled) + }, + }, + { + name: "Enable JWT groups", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 86400, + GroupsPropagationEnabled: &falseVal, + JwtGroupsEnabled: &trueVal, + JwtGroupsClaimName: stringPointer("groups"), + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.NotNil(t, account.Settings.JwtGroupsEnabled) + assert.Equal(t, true, *account.Settings.JwtGroupsEnabled) + assert.NotNil(t, account.Settings.JwtGroupsClaimName) + assert.Equal(t, "groups", *account.Settings.JwtGroupsClaimName) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, true, dbAccount.Settings.JWTGroupsEnabled) + assert.Equal(t, "groups", dbAccount.Settings.JWTGroupsClaimName) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/accounts/{accountId}", "{accountId}", testing_tools.TestAccountId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + got := &api.Account{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, testing_tools.TestAccountId, got.Id) + assert.Equal(t, "test.com", got.Domain) + tc.verifyResponse(t, got) + + db := testing_tools.GetDB(t, am.GetStore()) + dbAccount := testing_tools.VerifyAccountSettings(t, db) + tc.verifyDB(t, dbAccount) + }) + } + } +} + +func stringPointer(s string) *string { + return &s +} diff --git a/management/server/http/testing/integration/dns_handler_integration_test.go b/management/server/http/testing/integration/dns_handler_integration_test.go new file mode 100644 index 000000000..7ada5e462 --- /dev/null +++ b/management/server/http/testing/integration/dns_handler_integration_test.go @@ -0,0 +1,554 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Nameservers_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all nameservers", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/nameservers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.NameserverGroup{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testNSGroup", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Nameservers_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + nsGroupId string + expectedStatus int + expectGroup bool + }{ + { + name: "Get existing nameserver group", + nsGroupId: "testNSGroupId", + expectedStatus: http.StatusOK, + expectGroup: true, + }, + { + name: "Get non-existing nameserver group", + nsGroupId: "nonExistingNSGroupId", + expectedStatus: http.StatusNotFound, + expectGroup: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectGroup { + got := &api.NameserverGroup{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "testNSGroupId", got.Id) + assert.Equal(t, "testNSGroup", got.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Nameservers_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.PostApiDnsNameserversJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup) + }{ + { + name: "Create nameserver group with single NS", + requestBody: &api.PostApiDnsNameserversJSONRequestBody{ + Name: "newNSGroup", + Description: "a new nameserver group", + Nameservers: []api.Nameserver{ + {Ip: "8.8.8.8", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: false, + Domains: []string{"test.com"}, + Enabled: true, + SearchDomainsEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) { + t.Helper() + assert.NotEmpty(t, nsGroup.Id) + assert.Equal(t, "newNSGroup", nsGroup.Name) + assert.Equal(t, 1, len(nsGroup.Nameservers)) + assert.Equal(t, false, nsGroup.Primary) + }, + }, + { + name: "Create primary nameserver group", + requestBody: &api.PostApiDnsNameserversJSONRequestBody{ + Name: "primaryNS", + Description: "primary nameserver", + Nameservers: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: true, + Domains: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) { + t.Helper() + assert.Equal(t, true, nsGroup.Primary) + }, + }, + { + name: "Create nameserver group with empty groups", + requestBody: &api.PostApiDnsNameserversJSONRequestBody{ + Name: "emptyGroupsNS", + Description: "no groups", + Nameservers: []api.Nameserver{ + {Ip: "8.8.8.8", NsType: "udp", Port: 53}, + }, + Groups: []string{}, + Primary: true, + Domains: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/nameservers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NameserverGroup{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the created NS group directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbNS := testing_tools.VerifyNSGroupInDB(t, db, got.Id) + assert.Equal(t, got.Name, dbNS.Name) + assert.Equal(t, got.Primary, dbNS.Primary) + assert.Equal(t, len(got.Nameservers), len(dbNS.NameServers)) + assert.Equal(t, got.Enabled, dbNS.Enabled) + assert.Equal(t, got.SearchDomainsEnabled, dbNS.SearchDomainsEnabled) + } + }) + } + } +} + +func Test_Nameservers_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + nsGroupId string + requestBody *api.PutApiDnsNameserversNsgroupIdJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup) + }{ + { + name: "Update nameserver group name", + nsGroupId: "testNSGroupId", + requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{ + Name: "updatedNSGroup", + Description: "updated description", + Nameservers: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: false, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) { + t.Helper() + assert.Equal(t, "updatedNSGroup", nsGroup.Name) + assert.Equal(t, "updated description", nsGroup.Description) + }, + }, + { + name: "Update non-existing nameserver group", + nsGroupId: "nonExistingNSGroupId", + requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{ + Name: "whatever", + Nameservers: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: true, + Domains: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NameserverGroup{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the updated NS group directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbNS := testing_tools.VerifyNSGroupInDB(t, db, tc.nsGroupId) + assert.Equal(t, "updatedNSGroup", dbNS.Name) + assert.Equal(t, "updated description", dbNS.Description) + assert.Equal(t, false, dbNS.Primary) + assert.Equal(t, true, dbNS.Enabled) + assert.Equal(t, 1, len(dbNS.NameServers)) + assert.Equal(t, false, dbNS.SearchDomainsEnabled) + } + }) + } + } +} + +func Test_Nameservers_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + nsGroupId string + expectedStatus int + }{ + { + name: "Delete existing nameserver group", + nsGroupId: "testNSGroupId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing nameserver group", + nsGroupId: "nonExistingNSGroupId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify deletion in DB for successful deletes by privileged users + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyNSGroupNotInDB(t, db, tc.nsGroupId) + } + }) + } + } +} + +func Test_DnsSettings_Get(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get DNS settings", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/settings", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := &api.DNSSettings{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.NotNil(t, got.DisabledManagementGroups) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_DnsSettings_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.PutApiDnsSettingsJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, settings *api.DNSSettings) + expectedDBDisabledMgmtLen int + expectedDBDisabledMgmtItem string + }{ + { + name: "Update disabled management groups", + requestBody: &api.PutApiDnsSettingsJSONRequestBody{ + DisabledManagementGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, settings *api.DNSSettings) { + t.Helper() + assert.Equal(t, 1, len(settings.DisabledManagementGroups)) + assert.Equal(t, testing_tools.TestGroupId, settings.DisabledManagementGroups[0]) + }, + expectedDBDisabledMgmtLen: 1, + expectedDBDisabledMgmtItem: testing_tools.TestGroupId, + }, + { + name: "Update with empty disabled management groups", + requestBody: &api.PutApiDnsSettingsJSONRequestBody{ + DisabledManagementGroups: []string{}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, settings *api.DNSSettings) { + t.Helper() + assert.Equal(t, 0, len(settings.DisabledManagementGroups)) + }, + expectedDBDisabledMgmtLen: 0, + }, + { + name: "Update with non-existing group", + requestBody: &api.PutApiDnsSettingsJSONRequestBody{ + DisabledManagementGroups: []string{"nonExistingGroupId"}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/dns/settings", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.DNSSettings{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify DNS settings directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbAccount := testing_tools.VerifyAccountSettings(t, db) + assert.Equal(t, tc.expectedDBDisabledMgmtLen, len(dbAccount.DNSSettings.DisabledManagementGroups)) + if tc.expectedDBDisabledMgmtItem != "" { + assert.Contains(t, dbAccount.DNSSettings.DisabledManagementGroups, tc.expectedDBDisabledMgmtItem) + } + } + }) + } + } +} diff --git a/management/server/http/testing/integration/events_handler_integration_test.go b/management/server/http/testing/integration/events_handler_integration_test.go new file mode 100644 index 000000000..6611b60ee --- /dev/null +++ b/management/server/http/testing/integration/events_handler_integration_test.go @@ -0,0 +1,105 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Events_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all events", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false) + + // First, perform a mutation to generate an event (create a group as admin) + groupBody, err := json.Marshal(&api.GroupRequest{Name: "eventTestGroup"}) + if err != nil { + t.Fatalf("Failed to marshal group request: %v", err) + } + createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId) + createRecorder := httptest.NewRecorder() + apiHandler.ServeHTTP(createRecorder, createReq) + assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event") + + // Now query events + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Event{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group") + + // Verify the group creation event exists + found := false + for _, event := range got { + if event.ActivityCode == "group.add" { + found = true + assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId) + assert.Equal(t, "Group created", event.Activity) + break + } + } + assert.True(t, found, "Expected to find a group.add event") + }) + } +} + +func Test_Events_GetAll_Empty(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + if !expectResponse { + return + } + + got := []api.Event{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got), "Expected empty events list when no mutations have been performed") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } +} diff --git a/management/server/http/testing/integration/groups_handler_integration_test.go b/management/server/http/testing/integration/groups_handler_integration_test.go new file mode 100644 index 000000000..edb43f3f3 --- /dev/null +++ b/management/server/http/testing/integration/groups_handler_integration_test.go @@ -0,0 +1,382 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Groups_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all groups", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/groups", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Group{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 2) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Groups_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + groupId string + expectedStatus int + expectGroup bool + }{ + { + name: "Get existing group", + groupId: testing_tools.TestGroupId, + expectedStatus: http.StatusOK, + expectGroup: true, + }, + { + name: "Get non-existing group", + groupId: "nonExistingGroupId", + expectedStatus: http.StatusNotFound, + expectGroup: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectGroup { + got := &api.Group{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.groupId, got.Id) + assert.Equal(t, "testGroupName", got.Name) + assert.Equal(t, 1, got.PeersCount) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Groups_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.GroupRequest + expectedStatus int + verifyResponse func(t *testing.T, group *api.Group) + }{ + { + name: "Create group with valid name", + requestBody: &api.GroupRequest{ + Name: "brandNewGroup", + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.NotEmpty(t, group.Id) + assert.Equal(t, "brandNewGroup", group.Name) + assert.Equal(t, 0, group.PeersCount) + }, + }, + { + name: "Create group with peers", + requestBody: &api.GroupRequest{ + Name: "groupWithPeers", + Peers: &[]string{testing_tools.TestPeerId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.NotEmpty(t, group.Id) + assert.Equal(t, "groupWithPeers", group.Name) + assert.Equal(t, 1, group.PeersCount) + }, + }, + { + name: "Create group with empty name", + requestBody: &api.GroupRequest{ + Name: "", + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/groups", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Group{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify group exists in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id) + assert.Equal(t, tc.requestBody.Name, dbGroup.Name) + } + }) + } + } +} + +func Test_Groups_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + groupId string + requestBody *api.GroupRequest + expectedStatus int + verifyResponse func(t *testing.T, group *api.Group) + }{ + { + name: "Update group name", + groupId: testing_tools.TestGroupId, + requestBody: &api.GroupRequest{ + Name: "updatedGroupName", + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.Equal(t, testing_tools.TestGroupId, group.Id) + assert.Equal(t, "updatedGroupName", group.Name) + }, + }, + { + name: "Update group peers", + groupId: testing_tools.TestGroupId, + requestBody: &api.GroupRequest{ + Name: "testGroupName", + Peers: &[]string{}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.Equal(t, 0, group.PeersCount) + }, + }, + { + name: "Update with empty name", + groupId: testing_tools.TestGroupId, + requestBody: &api.GroupRequest{ + Name: "", + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Update non-existing group", + groupId: "nonExistingGroupId", + requestBody: &api.GroupRequest{ + Name: "someName", + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Group{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated group in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbGroup := testing_tools.VerifyGroupInDB(t, db, tc.groupId) + assert.Equal(t, tc.requestBody.Name, dbGroup.Name) + } + }) + } + } +} + +func Test_Groups_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + groupId string + expectedStatus int + }{ + { + name: "Delete existing group not in use", + groupId: testing_tools.NewGroupId, + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing group", + groupId: "nonExistingGroupId", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyGroupNotInDB(t, db, tc.groupId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/networks_handler_integration_test.go b/management/server/http/testing/integration/networks_handler_integration_test.go new file mode 100644 index 000000000..54f204a8f --- /dev/null +++ b/management/server/http/testing/integration/networks_handler_integration_test.go @@ -0,0 +1,1443 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Networks_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all networks", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.Network{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testNetworkId", got[0].Id) + assert.Equal(t, "testNetwork", got[0].Name) + assert.Equal(t, "test network description", *got[0].Description) + assert.GreaterOrEqual(t, len(got[0].Routers), 1) + assert.GreaterOrEqual(t, len(got[0].Resources), 1) + assert.GreaterOrEqual(t, got[0].RoutingPeersCount, 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Networks_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + expectedStatus int + expectNetwork bool + }{ + { + name: "Get existing network", + networkId: "testNetworkId", + expectedStatus: http.StatusOK, + expectNetwork: true, + }, + { + name: "Get non-existing network", + networkId: "nonExistingNetworkId", + expectedStatus: http.StatusNotFound, + expectNetwork: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/networks/{networkId}", "{networkId}", tc.networkId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectNetwork { + got := &api.Network{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.networkId, got.Id) + assert.Equal(t, "testNetwork", got.Name) + assert.Equal(t, "test network description", *got.Description) + assert.GreaterOrEqual(t, len(got.Routers), 1) + assert.GreaterOrEqual(t, len(got.Resources), 1) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Networks_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + desc := "new network description" + + tt := []struct { + name string + requestBody *api.NetworkRequest + expectedStatus int + verifyResponse func(t *testing.T, network *api.Network) + }{ + { + name: "Create network with name and description", + requestBody: &api.NetworkRequest{ + Name: "newNetwork", + Description: &desc, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, network *api.Network) { + t.Helper() + assert.NotEmpty(t, network.Id) + assert.Equal(t, "newNetwork", network.Name) + assert.Equal(t, "new network description", *network.Description) + assert.Empty(t, network.Routers) + assert.Empty(t, network.Resources) + assert.Equal(t, 0, network.RoutingPeersCount) + }, + }, + { + name: "Create network with name only", + requestBody: &api.NetworkRequest{ + Name: "simpleNetwork", + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, network *api.Network) { + t.Helper() + assert.NotEmpty(t, network.Id) + assert.Equal(t, "simpleNetwork", network.Name) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/networks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Network{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_Networks_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + updatedDesc := "updated description" + + tt := []struct { + name string + networkId string + requestBody *api.NetworkRequest + expectedStatus int + verifyResponse func(t *testing.T, network *api.Network) + }{ + { + name: "Update network name", + networkId: "testNetworkId", + requestBody: &api.NetworkRequest{ + Name: "updatedNetwork", + Description: &updatedDesc, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, network *api.Network) { + t.Helper() + assert.Equal(t, "testNetworkId", network.Id) + assert.Equal(t, "updatedNetwork", network.Name) + assert.Equal(t, "updated description", *network.Description) + }, + }, + { + name: "Update non-existing network", + networkId: "nonExistingNetworkId", + requestBody: &api.NetworkRequest{ + Name: "whatever", + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/networks/{networkId}", "{networkId}", tc.networkId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Network{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_Networks_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + expectedStatus int + }{ + { + name: "Delete existing network", + networkId: "testNetworkId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing network", + networkId: "nonExistingNetworkId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/networks/{networkId}", "{networkId}", tc.networkId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_Networks_Delete_Cascades(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + // Delete the network + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/networks/testNetworkId", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + + // Verify network is gone + req = testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId", testing_tools.TestAdminId) + recorder = httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + testing_tools.ReadResponse(t, recorder, http.StatusNotFound, true) + + // Verify routers in that network are gone + req = testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/routers", testing_tools.TestAdminId) + recorder = httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + content, _ := testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + var routers []*api.NetworkRouter + require.NoError(t, json.Unmarshal(content, &routers)) + assert.Empty(t, routers) + + // Verify resources in that network are gone + req = testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/resources", testing_tools.TestAdminId) + recorder = httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + content, _ = testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + var resources []*api.NetworkResource + require.NoError(t, json.Unmarshal(content, &resources)) + assert.Empty(t, resources) +} + +func Test_NetworkResources_GetAllInNetwork(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all resources in network", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/resources", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkResource{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testResourceId", got[0].Id) + assert.Equal(t, "testResource", got[0].Name) + assert.Equal(t, api.NetworkResourceType("host"), got[0].Type) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkResources_GetAllInAccount(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all resources in account", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/resources", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkResource{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkResources_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + resourceId string + expectedStatus int + expectResource bool + }{ + { + name: "Get existing resource", + networkId: "testNetworkId", + resourceId: "testResourceId", + expectedStatus: http.StatusOK, + expectResource: true, + }, + { + name: "Get non-existing resource", + networkId: "testNetworkId", + resourceId: "nonExistingResourceId", + expectedStatus: http.StatusNotFound, + expectResource: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + path := fmt.Sprintf("/api/networks/%s/resources/%s", tc.networkId, tc.resourceId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectResource { + got := &api.NetworkResource{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.resourceId, got.Id) + assert.Equal(t, "testResource", got.Name) + assert.Equal(t, api.NetworkResourceType("host"), got.Type) + assert.Equal(t, "3.3.3.3/32", got.Address) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_NetworkResources_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + desc := "new resource" + + tt := []struct { + name string + networkId string + requestBody *api.NetworkResourceRequest + expectedStatus int + verifyResponse func(t *testing.T, resource *api.NetworkResource) + }{ + { + name: "Create host resource with IP", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "hostResource", + Description: &desc, + Address: "1.1.1.1", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.NotEmpty(t, resource.Id) + assert.Equal(t, "hostResource", resource.Name) + assert.Equal(t, api.NetworkResourceType("host"), resource.Type) + assert.Equal(t, "1.1.1.1/32", resource.Address) + assert.True(t, resource.Enabled) + }, + }, + { + name: "Create host resource with CIDR /32", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "hostCIDR", + Address: "10.0.0.1/32", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("host"), resource.Type) + assert.Equal(t, "10.0.0.1/32", resource.Address) + }, + }, + { + name: "Create subnet resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "subnetResource", + Address: "192.168.0.0/24", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("subnet"), resource.Type) + assert.Equal(t, "192.168.0.0/24", resource.Address) + }, + }, + { + name: "Create domain resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "domainResource", + Address: "example.com", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("domain"), resource.Type) + assert.Equal(t, "example.com", resource.Address) + }, + }, + { + name: "Create wildcard domain resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "wildcardDomain", + Address: "*.example.com", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("domain"), resource.Type) + assert.Equal(t, "*.example.com", resource.Address) + }, + }, + { + name: "Create disabled resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "disabledResource", + Address: "5.5.5.5", + Groups: []string{testing_tools.TestGroupId}, + Enabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.False(t, resource.Enabled) + }, + }, + { + name: "Create resource with invalid address", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "invalidResource", + Address: "not-a-valid-address!!!", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Create resource with empty groups", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "noGroupsResource", + Address: "7.7.7.7", + Groups: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.NotEmpty(t, resource.Id) + }, + }, + { + name: "Create resource with duplicate name", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "testResource", + Address: "8.8.8.8", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/resources", tc.networkId) + req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkResource{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkResources_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + updatedDesc := "updated resource" + + tt := []struct { + name string + networkId string + resourceId string + requestBody *api.NetworkResourceRequest + expectedStatus int + verifyResponse func(t *testing.T, resource *api.NetworkResource) + }{ + { + name: "Update resource name and address", + networkId: "testNetworkId", + resourceId: "testResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "updatedResource", + Description: &updatedDesc, + Address: "4.4.4.4", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, "testResourceId", resource.Id) + assert.Equal(t, "updatedResource", resource.Name) + assert.Equal(t, "updated resource", *resource.Description) + assert.Equal(t, "4.4.4.4/32", resource.Address) + }, + }, + { + name: "Update resource to subnet type", + networkId: "testNetworkId", + resourceId: "testResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "testResource", + Address: "10.0.0.0/16", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("subnet"), resource.Type) + assert.Equal(t, "10.0.0.0/16", resource.Address) + }, + }, + { + name: "Update resource to domain type", + networkId: "testNetworkId", + resourceId: "testResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "testResource", + Address: "myservice.example.com", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("domain"), resource.Type) + assert.Equal(t, "myservice.example.com", resource.Address) + }, + }, + { + name: "Update non-existing resource", + networkId: "testNetworkId", + resourceId: "nonExistingResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "whatever", + Address: "1.2.3.4", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/resources/%s", tc.networkId, tc.resourceId) + req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkResource{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkResources_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + resourceId string + expectedStatus int + }{ + { + name: "Delete existing resource", + networkId: "testNetworkId", + resourceId: "testResourceId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing resource", + networkId: "testNetworkId", + resourceId: "nonExistingResourceId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + path := fmt.Sprintf("/api/networks/%s/resources/%s", tc.networkId, tc.resourceId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_NetworkRouters_GetAllInNetwork(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all routers in network", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/routers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkRouter{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testRouterId", got[0].Id) + assert.Equal(t, "testPeerId", *got[0].Peer) + assert.True(t, got[0].Masquerade) + assert.Equal(t, 100, got[0].Metric) + assert.True(t, got[0].Enabled) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkRouters_GetAllInAccount(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all routers in account", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/routers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkRouter{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkRouters_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + routerId string + expectedStatus int + expectRouter bool + }{ + { + name: "Get existing router", + networkId: "testNetworkId", + routerId: "testRouterId", + expectedStatus: http.StatusOK, + expectRouter: true, + }, + { + name: "Get non-existing router", + networkId: "testNetworkId", + routerId: "nonExistingRouterId", + expectedStatus: http.StatusNotFound, + expectRouter: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + path := fmt.Sprintf("/api/networks/%s/routers/%s", tc.networkId, tc.routerId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectRouter { + got := &api.NetworkRouter{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.routerId, got.Id) + assert.Equal(t, "testPeerId", *got.Peer) + assert.True(t, got.Masquerade) + assert.Equal(t, 100, got.Metric) + assert.True(t, got.Enabled) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_NetworkRouters_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + peerID := "testPeerId" + peerGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + networkId string + requestBody *api.NetworkRouterRequest + expectedStatus int + verifyResponse func(t *testing.T, router *api.NetworkRouter) + }{ + { + name: "Create router with peer", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 200, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotEmpty(t, router.Id) + assert.Equal(t, peerID, *router.Peer) + assert.True(t, router.Masquerade) + assert.Equal(t, 200, router.Metric) + assert.True(t, router.Enabled) + }, + }, + { + name: "Create router with peer groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + PeerGroups: &peerGroups, + Masquerade: false, + Metric: 300, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotEmpty(t, router.Id) + assert.NotNil(t, router.PeerGroups) + assert.Equal(t, 1, len(*router.PeerGroups)) + assert.False(t, router.Masquerade) + assert.Equal(t, 300, router.Metric) + assert.True(t, router.Enabled) // always true on creation + }, + }, + { + name: "Create router with both peer and peer_groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + PeerGroups: &peerGroups, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Create router without peer and peer_groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Create router in non-existing network", + networkId: "nonExistingNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusNotFound, + }, + { + name: "Create router enabled is always true", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: false, + Metric: 50, + Enabled: false, // handler sets to true + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.True(t, router.Enabled) // always true on creation + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/routers", tc.networkId) + req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkRouter{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkRouters_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + peerID := "testPeerId" + peerGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + networkId string + routerId string + requestBody *api.NetworkRouterRequest + expectedStatus int + verifyResponse func(t *testing.T, router *api.NetworkRouter) + }{ + { + name: "Update router metric and masquerade", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: false, + Metric: 500, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.Equal(t, "testRouterId", router.Id) + assert.False(t, router.Masquerade) + assert.Equal(t, 500, router.Metric) + }, + }, + { + name: "Update router to use peer groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + PeerGroups: &peerGroups, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotNil(t, router.PeerGroups) + assert.Equal(t, 1, len(*router.PeerGroups)) + }, + }, + { + name: "Update router disabled", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 100, + Enabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.False(t, router.Enabled) + }, + }, + { + name: "Update non-existing router creates it", + networkId: "testNetworkId", + routerId: "nonExistingRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.Equal(t, "nonExistingRouterId", router.Id) + }, + }, + { + name: "Update router with both peer and peer_groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + PeerGroups: &peerGroups, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Update router without peer and peer_groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/routers/%s", tc.networkId, tc.routerId) + req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkRouter{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkRouters_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + routerId string + expectedStatus int + }{ + { + name: "Delete existing router", + networkId: "testNetworkId", + routerId: "testRouterId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing router", + networkId: "testNetworkId", + routerId: "nonExistingRouterId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + path := fmt.Sprintf("/api/networks/%s/routers/%s", tc.networkId, tc.routerId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} diff --git a/management/server/http/testing/integration/peers_handler_integration_test.go b/management/server/http/testing/integration/peers_handler_integration_test.go new file mode 100644 index 000000000..17a9e94a6 --- /dev/null +++ b/management/server/http/testing/integration/peers_handler_integration_test.go @@ -0,0 +1,605 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +const ( + testPeerId2 = "testPeerId2" +) + +func Test_Peers_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: true, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + for _, user := range users { + t.Run(user.name+" - Get all peers", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/peers", user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + var got []api.PeerBatch + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 2, "Expected at least 2 peers") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Peers_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: true, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestType string + requestPath string + requestId string + verifyResponse func(t *testing.T, peer *api.Peer) + }{ + { + name: "Get existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, "test-peer-1", peer.Name) + assert.Equal(t, "test-host-1", peer.Hostname) + assert.Equal(t, "Debian GNU/Linux ", peer.Os) + assert.Equal(t, "0.12.0", peer.Version) + assert.Equal(t, false, peer.SshEnabled) + assert.Equal(t, true, peer.LoginExpirationEnabled) + }, + }, + { + name: "Get second existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}", + requestId: testPeerId2, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testPeerId2, peer.Id) + assert.Equal(t, "test-peer-2", peer.Name) + assert.Equal(t, "test-host-2", peer.Hostname) + assert.Equal(t, "Ubuntu ", peer.Os) + assert.Equal(t, true, peer.SshEnabled) + assert.Equal(t, false, peer.LoginExpirationEnabled) + assert.Equal(t, true, peer.Connected) + }, + }, + { + name: "Get non-existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}", + requestId: "nonExistingPeerId", + expectedStatus: http.StatusNotFound, + verifyResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Peer{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Peers_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestBody *api.PeerRequest + requestType string + requestPath string + requestId string + verifyResponse func(t *testing.T, peer *api.Peer) + }{ + { + name: "Update peer name", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + requestBody: &api.PeerRequest{ + Name: "updated-peer-name", + SshEnabled: false, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, "updated-peer-name", peer.Name) + assert.Equal(t, false, peer.SshEnabled) + assert.Equal(t, true, peer.LoginExpirationEnabled) + }, + }, + { + name: "Enable SSH on peer", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + requestBody: &api.PeerRequest{ + Name: "test-peer-1", + SshEnabled: true, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, "test-peer-1", peer.Name) + assert.Equal(t, true, peer.SshEnabled) + assert.Equal(t, true, peer.LoginExpirationEnabled) + }, + }, + { + name: "Disable login expiration on peer", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + requestBody: &api.PeerRequest{ + Name: "test-peer-1", + SshEnabled: false, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, false, peer.LoginExpirationEnabled) + }, + }, + { + name: "Update non-existing peer", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: "nonExistingPeerId", + requestBody: &api.PeerRequest{ + Name: "updated-name", + SshEnabled: false, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusNotFound, + verifyResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Peer{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated peer in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbPeer := testing_tools.VerifyPeerInDB(t, db, tc.requestId) + assert.Equal(t, tc.requestBody.Name, dbPeer.Name) + assert.Equal(t, tc.requestBody.SshEnabled, dbPeer.SSHEnabled) + assert.Equal(t, tc.requestBody.LoginExpirationEnabled, dbPeer.LoginExpirationEnabled) + assert.Equal(t, tc.requestBody.InactivityExpirationEnabled, dbPeer.InactivityExpirationEnabled) + } + }) + } + } +} + +func Test_Peers_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestType string + requestPath string + requestId string + }{ + { + name: "Delete existing peer", + requestType: http.MethodDelete, + requestPath: "/api/peers/{peerId}", + requestId: testPeerId2, + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing peer", + requestType: http.MethodDelete, + requestPath: "/api/peers/{peerId}", + requestId: "nonExistingPeerId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + // Verify peer is actually deleted in DB + if tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPeerNotInDB(t, db, tc.requestId) + } + }) + } + } +} + +func Test_Peers_GetAccessiblePeers(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestType string + requestPath string + requestId string + }{ + { + name: "Get accessible peers for existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}/accessible-peers", + requestId: testing_tools.TestPeerId, + expectedStatus: http.StatusOK, + }, + { + name: "Get accessible peers for non-existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}/accessible-peers", + requestId: "nonExistingPeerId", + expectedStatus: http.StatusOK, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectedStatus == http.StatusOK { + var got []api.AccessiblePeer + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + // The accessible peers list should be a valid array (may be empty if no policies connect peers) + assert.NotNil(t, got, "Expected accessible peers to be a valid array") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} diff --git a/management/server/http/testing/integration/policies_handler_integration_test.go b/management/server/http/testing/integration/policies_handler_integration_test.go new file mode 100644 index 000000000..6f3624fb5 --- /dev/null +++ b/management/server/http/testing/integration/policies_handler_integration_test.go @@ -0,0 +1,488 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Policies_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all policies", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/policies", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Policy{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testPolicy", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Policies_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + policyId string + expectedStatus int + expectPolicy bool + }{ + { + name: "Get existing policy", + policyId: "testPolicyId", + expectedStatus: http.StatusOK, + expectPolicy: true, + }, + { + name: "Get non-existing policy", + policyId: "nonExistingPolicyId", + expectedStatus: http.StatusNotFound, + expectPolicy: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectPolicy { + got := &api.Policy{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.NotNil(t, got.Id) + assert.Equal(t, tc.policyId, *got.Id) + assert.Equal(t, "testPolicy", got.Name) + assert.Equal(t, true, got.Enabled) + assert.GreaterOrEqual(t, len(got.Rules), 1) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Policies_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + srcGroups := []string{testing_tools.TestGroupId} + dstGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + requestBody *api.PolicyCreate + expectedStatus int + verifyResponse func(t *testing.T, policy *api.Policy) + }{ + { + name: "Create policy with accept rule", + requestBody: &api.PolicyCreate{ + Name: "newPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "allowAll", + Enabled: true, + Action: "accept", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.NotNil(t, policy.Id) + assert.Equal(t, "newPolicy", policy.Name) + assert.Equal(t, true, policy.Enabled) + assert.Equal(t, 1, len(policy.Rules)) + assert.Equal(t, "allowAll", policy.Rules[0].Name) + }, + }, + { + name: "Create policy with drop rule", + requestBody: &api.PolicyCreate{ + Name: "dropPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "dropAll", + Enabled: true, + Action: "drop", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, "dropPolicy", policy.Name) + }, + }, + { + name: "Create policy with TCP rule and ports", + requestBody: &api.PolicyCreate{ + Name: "tcpPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "tcpRule", + Enabled: true, + Action: "accept", + Protocol: "tcp", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + Ports: &[]string{"80", "443"}, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, "tcpPolicy", policy.Name) + assert.NotNil(t, policy.Rules[0].Ports) + assert.Equal(t, 2, len(*policy.Rules[0].Ports)) + }, + }, + { + name: "Create policy with empty name", + requestBody: &api.PolicyCreate{ + Name: "", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "rule", + Enabled: true, + Action: "accept", + Protocol: "all", + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create policy with no rules", + requestBody: &api.PolicyCreate{ + Name: "noRulesPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/policies", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Policy{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify policy exists in DB with correct fields + db := testing_tools.GetDB(t, am.GetStore()) + dbPolicy := testing_tools.VerifyPolicyInDB(t, db, *got.Id) + assert.Equal(t, tc.requestBody.Name, dbPolicy.Name) + assert.Equal(t, tc.requestBody.Enabled, dbPolicy.Enabled) + assert.Equal(t, len(tc.requestBody.Rules), len(dbPolicy.Rules)) + } + }) + } + } +} + +func Test_Policies_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + srcGroups := []string{testing_tools.TestGroupId} + dstGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + policyId string + requestBody *api.PolicyCreate + expectedStatus int + verifyResponse func(t *testing.T, policy *api.Policy) + }{ + { + name: "Update policy name", + policyId: "testPolicyId", + requestBody: &api.PolicyCreate{ + Name: "updatedPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "testRule", + Enabled: true, + Action: "accept", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, "updatedPolicy", policy.Name) + }, + }, + { + name: "Update policy enabled state", + policyId: "testPolicyId", + requestBody: &api.PolicyCreate{ + Name: "testPolicy", + Enabled: false, + Rules: []api.PolicyRuleUpdate{ + { + Name: "testRule", + Enabled: true, + Action: "accept", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, false, policy.Enabled) + }, + }, + { + name: "Update non-existing policy", + policyId: "nonExistingPolicyId", + requestBody: &api.PolicyCreate{ + Name: "whatever", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "rule", + Enabled: true, + Action: "accept", + Protocol: "all", + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Policy{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated policy in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbPolicy := testing_tools.VerifyPolicyInDB(t, db, tc.policyId) + assert.Equal(t, tc.requestBody.Name, dbPolicy.Name) + assert.Equal(t, tc.requestBody.Enabled, dbPolicy.Enabled) + } + }) + } + } +} + +func Test_Policies_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + policyId string + expectedStatus int + }{ + { + name: "Delete existing policy", + policyId: "testPolicyId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing policy", + policyId: "nonExistingPolicyId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPolicyNotInDB(t, db, tc.policyId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/routes_handler_integration_test.go b/management/server/http/testing/integration/routes_handler_integration_test.go new file mode 100644 index 000000000..eeb0c3025 --- /dev/null +++ b/management/server/http/testing/integration/routes_handler_integration_test.go @@ -0,0 +1,455 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Routes_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all routes", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/routes", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Route{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 2, len(got)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Routes_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + routeId string + expectedStatus int + expectRoute bool + }{ + { + name: "Get existing route", + routeId: "testRouteId", + expectedStatus: http.StatusOK, + expectRoute: true, + }, + { + name: "Get non-existing route", + routeId: "nonExistingRouteId", + expectedStatus: http.StatusNotFound, + expectRoute: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectRoute { + got := &api.Route{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.routeId, got.Id) + assert.Equal(t, "Test Network Route", got.Description) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Routes_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + networkCIDR := "10.10.0.0/24" + peerID := testing_tools.TestPeerId + peerGroups := []string{"peerGroupId"} + + tt := []struct { + name string + requestBody *api.RouteRequest + expectedStatus int + verifyResponse func(t *testing.T, route *api.Route) + }{ + { + name: "Create network route with peer", + requestBody: &api.RouteRequest{ + Description: "New network route", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "newNet", + Metric: 100, + Masquerade: true, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.NotEmpty(t, route.Id) + assert.Equal(t, "New network route", route.Description) + assert.Equal(t, 100, route.Metric) + assert.Equal(t, true, route.Masquerade) + assert.Equal(t, true, route.Enabled) + }, + }, + { + name: "Create network route with peer groups", + requestBody: &api.RouteRequest{ + Description: "Route with peer groups", + Network: &networkCIDR, + PeerGroups: &peerGroups, + NetworkId: "peerGroupNet", + Metric: 150, + Masquerade: false, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.NotEmpty(t, route.Id) + assert.Equal(t, "Route with peer groups", route.Description) + }, + }, + { + name: "Create route with empty network_id", + requestBody: &api.RouteRequest{ + Description: "Empty net id", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "", + Metric: 100, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create route with metric 0", + requestBody: &api.RouteRequest{ + Description: "Zero metric", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "zeroMetric", + Metric: 0, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create route with metric 10000", + requestBody: &api.RouteRequest{ + Description: "High metric", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "highMetric", + Metric: 10000, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/routes", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Route{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify route exists in DB with correct fields + db := testing_tools.GetDB(t, am.GetStore()) + dbRoute := testing_tools.VerifyRouteInDB(t, db, route.ID(got.Id)) + assert.Equal(t, tc.requestBody.Description, dbRoute.Description) + assert.Equal(t, tc.requestBody.Metric, dbRoute.Metric) + assert.Equal(t, tc.requestBody.Masquerade, dbRoute.Masquerade) + assert.Equal(t, tc.requestBody.Enabled, dbRoute.Enabled) + assert.Equal(t, route.NetID(tc.requestBody.NetworkId), dbRoute.NetID) + } + }) + } + } +} + +func Test_Routes_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + networkCIDR := "10.0.0.0/24" + peerID := testing_tools.TestPeerId + + tt := []struct { + name string + routeId string + requestBody *api.RouteRequest + expectedStatus int + verifyResponse func(t *testing.T, route *api.Route) + }{ + { + name: "Update route description", + routeId: "testRouteId", + requestBody: &api.RouteRequest{ + Description: "Updated description", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "testNet", + Metric: 100, + Masquerade: true, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.Equal(t, "testRouteId", route.Id) + assert.Equal(t, "Updated description", route.Description) + }, + }, + { + name: "Update route metric", + routeId: "testRouteId", + requestBody: &api.RouteRequest{ + Description: "Test Network Route", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "testNet", + Metric: 500, + Masquerade: true, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.Equal(t, 500, route.Metric) + }, + }, + { + name: "Update non-existing route", + routeId: "nonExistingRouteId", + requestBody: &api.RouteRequest{ + Description: "whatever", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "testNet", + Metric: 100, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Route{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated route in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbRoute := testing_tools.VerifyRouteInDB(t, db, route.ID(got.Id)) + assert.Equal(t, tc.requestBody.Description, dbRoute.Description) + assert.Equal(t, tc.requestBody.Metric, dbRoute.Metric) + assert.Equal(t, tc.requestBody.Masquerade, dbRoute.Masquerade) + assert.Equal(t, tc.requestBody.Enabled, dbRoute.Enabled) + } + }) + } + } +} + +func Test_Routes_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + routeId string + expectedStatus int + }{ + { + name: "Delete existing route", + routeId: "testRouteId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing route", + routeId: "nonExistingRouteId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify route was deleted from DB for successful deletes + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyRouteNotInDB(t, db, route.ID(tc.routeId)) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index c1a9829da..0d3aaac82 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -3,7 +3,6 @@ package integration import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -14,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/http/api" @@ -254,7 +252,7 @@ func Test_SetupKeys_Create(t *testing.T) { expectedResponse: nil, }, { - name: "Create Setup Key", + name: "Create Setup Key with nil AutoGroups", requestType: http.MethodPost, requestPath: "/api/setup-keys", requestBody: &api.CreateSetupKeyRequest{ @@ -308,14 +306,15 @@ func Test_SetupKeys_Create(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + gotID := got.Id validateCreatedKey(t, tc.expectedResponse, got) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + // Verify setup key exists in DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, tc.expectedResponse.Name, dbKey.Name) + assert.Equal(t, tc.expectedResponse.Revoked, dbKey.Revoked) + assert.Equal(t, tc.expectedResponse.UsageLimit, dbKey.UsageLimit) select { case <-done: @@ -571,7 +570,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { - t.Run(tc.name, func(t *testing.T) { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) @@ -594,14 +593,16 @@ func Test_SetupKeys_Update(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + gotID := got.Id + gotRevoked := got.Revoked + gotUsageLimit := got.UsageLimit validateCreatedKey(t, tc.expectedResponse, got) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + // Verify updated setup key in DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, gotRevoked, dbKey.Revoked) + assert.Equal(t, gotUsageLimit, dbKey.UsageLimit) select { case <-done: @@ -759,8 +760,8 @@ func Test_SetupKeys_Get(t *testing.T) { apiHandler.ServeHTTP(recorder, req) - content, expectRespnose := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) - if !expectRespnose { + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { return } got := &api.SetupKey{} @@ -768,14 +769,16 @@ func Test_SetupKeys_Get(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + gotID := got.Id + gotName := got.Name + gotRevoked := got.Revoked validateCreatedKey(t, tc.expectedResponse, got) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + // Verify setup key in DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, gotName, dbKey.Name) + assert.Equal(t, gotRevoked, dbKey.Revoked) select { case <-done: @@ -928,15 +931,17 @@ func Test_SetupKeys_GetAll(t *testing.T) { return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit }) + db := testing_tools.GetDB(t, am.GetStore()) for i := range tc.expectedResponse { + gotID := got[i].Id + gotName := got[i].Name + gotRevoked := got[i].Revoked validateCreatedKey(t, tc.expectedResponse[i], &got[i]) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key)) + // Verify each setup key in DB via gorm + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, gotName, dbKey.Name) + assert.Equal(t, gotRevoked, dbKey.Revoked) } select { @@ -1104,8 +1109,9 @@ func Test_SetupKeys_Delete(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } - _, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - assert.Errorf(t, err, "Expected error when trying to get deleted key") + // Verify setup key deleted from DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifySetupKeyNotInDB(t, db, got.Id) select { case <-done: @@ -1120,7 +1126,7 @@ func Test_SetupKeys_Delete(t *testing.T) { func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) { t.Helper() - if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) || + if (got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second))) || got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) || got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) { got.Expires = time.Time{} diff --git a/management/server/http/testing/integration/users_handler_integration_test.go b/management/server/http/testing/integration/users_handler_integration_test.go new file mode 100644 index 000000000..eae3b4ad5 --- /dev/null +++ b/management/server/http/testing/integration/users_handler_integration_test.go @@ -0,0 +1,701 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Users_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, true}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all users", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.User{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Users_GetAll_ServiceUsers(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all service users", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users?service_user=true", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.User{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + for _, u := range got { + assert.NotNil(t, u.IsServiceUser) + assert.Equal(t, true, *u.IsServiceUser) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Users_Create_ServiceUser(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.UserCreateRequest + expectedStatus int + verifyResponse func(t *testing.T, user *api.User) + }{ + { + name: "Create service user with admin role", + requestBody: &api.UserCreateRequest{ + Role: "admin", + IsServiceUser: true, + AutoGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.NotEmpty(t, user.Id) + assert.Equal(t, "admin", user.Role) + assert.NotNil(t, user.IsServiceUser) + assert.Equal(t, true, *user.IsServiceUser) + }, + }, + { + name: "Create service user with user role", + requestBody: &api.UserCreateRequest{ + Role: "user", + IsServiceUser: true, + AutoGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.NotEmpty(t, user.Id) + assert.Equal(t, "user", user.Role) + }, + }, + { + name: "Create service user with empty auto_groups", + requestBody: &api.UserCreateRequest{ + Role: "admin", + IsServiceUser: true, + AutoGroups: []string{}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.NotEmpty(t, user.Id) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.User{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify user in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbUser := testing_tools.VerifyUserInDB(t, db, got.Id) + assert.True(t, dbUser.IsServiceUser) + assert.Equal(t, string(dbUser.Role), string(tc.requestBody.Role)) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Users_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + requestBody *api.UserRequest + expectedStatus int + verifyResponse func(t *testing.T, user *api.User) + }{ + { + name: "Update user role to admin", + targetUserId: testing_tools.TestUserId, + requestBody: &api.UserRequest{ + Role: "admin", + AutoGroups: []string{}, + IsBlocked: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.Equal(t, "admin", user.Role) + }, + }, + { + name: "Update user auto_groups", + targetUserId: testing_tools.TestUserId, + requestBody: &api.UserRequest{ + Role: "user", + AutoGroups: []string{testing_tools.TestGroupId}, + IsBlocked: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.Equal(t, 1, len(user.AutoGroups)) + }, + }, + { + name: "Block user", + targetUserId: testing_tools.TestUserId, + requestBody: &api.UserRequest{ + Role: "user", + AutoGroups: []string{}, + IsBlocked: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.Equal(t, true, user.IsBlocked) + }, + }, + { + name: "Update non-existing user", + targetUserId: "nonExistingUserId", + requestBody: &api.UserRequest{ + Role: "user", + AutoGroups: []string{}, + IsBlocked: false, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/users/{userId}", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.User{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated fields in DB + if tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + dbUser := testing_tools.VerifyUserInDB(t, db, tc.targetUserId) + assert.Equal(t, string(dbUser.Role), string(tc.requestBody.Role)) + assert.Equal(t, dbUser.Blocked, tc.requestBody.IsBlocked) + assert.ElementsMatch(t, dbUser.AutoGroups, tc.requestBody.AutoGroups) + } + } + }) + } + } +} + +func Test_Users_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + expectedStatus int + }{ + { + name: "Delete existing service user", + targetUserId: "deletableServiceUserId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing user", + targetUserId: "nonExistingUserId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify user deleted from DB for successful deletes + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PATs_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all PATs for service user", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/users/{userId}/tokens", "{userId}", testing_tools.TestServiceUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.PersonalAccessToken{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "serviceToken", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_PATs_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + tokenId string + expectedStatus int + expectToken bool + }{ + { + name: "Get existing PAT", + tokenId: "serviceTokenId", + expectedStatus: http.StatusOK, + expectToken: true, + }, + { + name: "Get non-existing PAT", + tokenId: "nonExistingTokenId", + expectedStatus: http.StatusNotFound, + expectToken: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + path := strings.Replace("/api/users/{userId}/tokens/{tokenId}", "{userId}", testing_tools.TestServiceUserId, 1) + path = strings.Replace(path, "{tokenId}", tc.tokenId, 1) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectToken { + got := &api.PersonalAccessToken{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "serviceTokenId", got.Id) + assert.Equal(t, "serviceToken", got.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PATs_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + requestBody *api.PersonalAccessTokenRequest + expectedStatus int + verifyResponse func(t *testing.T, pat *api.PersonalAccessTokenGenerated) + }{ + { + name: "Create PAT with 30 day expiry", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "newPAT", + ExpiresIn: 30, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, pat *api.PersonalAccessTokenGenerated) { + t.Helper() + assert.NotEmpty(t, pat.PlainToken) + assert.Equal(t, "newPAT", pat.PersonalAccessToken.Name) + }, + }, + { + name: "Create PAT with 365 day expiry", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "longPAT", + ExpiresIn: 365, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, pat *api.PersonalAccessTokenGenerated) { + t.Helper() + assert.NotEmpty(t, pat.PlainToken) + assert.Equal(t, "longPAT", pat.PersonalAccessToken.Name) + }, + }, + { + name: "Create PAT with empty name", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "", + ExpiresIn: 30, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create PAT with 0 day expiry", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "zeroPAT", + ExpiresIn: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create PAT with expiry over 365 days", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "tooLongPAT", + ExpiresIn: 400, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, strings.Replace("/api/users/{userId}/tokens", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.PersonalAccessTokenGenerated{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify PAT in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbPAT := testing_tools.VerifyPATInDB(t, db, got.PersonalAccessToken.Id) + assert.Equal(t, tc.requestBody.Name, dbPAT.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PATs_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + tokenId string + expectedStatus int + }{ + { + name: "Delete existing PAT", + tokenId: "serviceTokenId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing PAT", + tokenId: "nonExistingTokenId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + path := strings.Replace("/api/users/{userId}/tokens/{tokenId}", "{userId}", testing_tools.TestServiceUserId, 1) + path = strings.Replace(path, "{tokenId}", tc.tokenId, 1) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify PAT deleted from DB for successful deletes + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPATNotInDB(t, db, tc.tokenId) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} diff --git a/management/server/http/testing/testdata/accounts.sql b/management/server/http/testing/testdata/accounts.sql new file mode 100644 index 000000000..35f00d419 --- /dev/null +++ b/management/server/http/testing/testdata/accounts.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); diff --git a/management/server/http/testing/testdata/dns.sql b/management/server/http/testing/testdata/dns.sql new file mode 100644 index 000000000..9ed4daf7e --- /dev/null +++ b/management/server/http/testing/testdata/dns.sql @@ -0,0 +1,21 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/events.sql b/management/server/http/testing/testdata/events.sql new file mode 100644 index 000000000..27fd01aea --- /dev/null +++ b/management/server/http/testing/testdata/events.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/groups.sql b/management/server/http/testing/testdata/groups.sql new file mode 100644 index 000000000..eb874f036 --- /dev/null +++ b/management/server/http/testing/testdata/groups.sql @@ -0,0 +1,19 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/networks.sql b/management/server/http/testing/testdata/networks.sql new file mode 100644 index 000000000..39ec8e646 --- /dev/null +++ b/management/server/http/testing/testdata/networks.sql @@ -0,0 +1,25 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`domain` text,`prefix` text,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO networks VALUES('testNetworkId','testAccountId','testNetwork','test network description'); +INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','testPeerId','[]',1,100,1); +INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','testResource','test resource description','host','','"3.3.3.3/32"',1); \ No newline at end of file diff --git a/management/server/http/testing/testdata/peers_integration.sql b/management/server/http/testing/testdata/peers_integration.sql new file mode 100644 index 000000000..62a7760e7 --- /dev/null +++ b/management/server/http/testing/testdata/peers_integration.sql @@ -0,0 +1,20 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId","testPeerId2"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); + +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','test-host-1','linux','Linux','','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-1','test-peer-1','2023-03-02 09:21:02.189035775+01:00',0,0,0,'testUserId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('testPeerId2','testAccountId','6rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYBg=','82546A29-6BC8-4311-BCFC-9CDBF33F1A49','"100.64.114.32"','test-host-2','linux','Linux','','unknown','Ubuntu','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-2','test-peer-2','2023-03-02 09:21:02.189035775+01:00',1,0,0,'testAdminId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/policies.sql b/management/server/http/testing/testdata/policies.sql new file mode 100644 index 000000000..7e6cc883b --- /dev/null +++ b/management/server/http/testing/testdata/policies.sql @@ -0,0 +1,23 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`protocol` text,`bidirectional` numeric,`sources` text,`destinations` text,`source_resource` text,`destination_resource` text,`ports` text,`port_ranges` text,`authorized_groups` text,`authorized_user` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules_g` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO policies VALUES('testPolicyId','testAccountId','testPolicy','test policy description',1,NULL); +INSERT INTO policy_rules VALUES('testRuleId','testPolicyId','testRule','test rule',1,'accept','all',1,'["testGroupId"]','["testGroupId"]',NULL,NULL,NULL,NULL,NULL,''); \ No newline at end of file diff --git a/management/server/http/testing/testdata/routes.sql b/management/server/http/testing/testdata/routes.sql new file mode 100644 index 000000000..48aa02052 --- /dev/null +++ b/management/server/http/testing/testdata/routes.sql @@ -0,0 +1,23 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,`skip_auto_apply` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('peerGroupId','testAccountId','peerGroupName','api','["testPeerId"]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO routes VALUES('testRouteId','testAccountId','"10.0.0.0/24"',NULL,0,'testNet','Test Network Route','testPeerId',NULL,1,1,100,1,'["testGroupId"]',NULL,0); +INSERT INTO routes VALUES('testDomainRouteId','testAccountId','"0.0.0.0/0"','["example.com"]',0,'testDomainNet','Test Domain Route','','["peerGroupId"]',3,1,200,1,'["testGroupId"]',NULL,0); diff --git a/management/server/http/testing/testdata/users_integration.sql b/management/server/http/testing/testdata/users_integration.sql new file mode 100644 index 000000000..57df73e8c --- /dev/null +++ b/management/server/http/testing/testdata/users_integration.sql @@ -0,0 +1,24 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime DEFAULT NULL,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('deletableServiceUserId','testAccountId','user',1,0,'deletableServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO personal_access_tokens VALUES('testTokenId','testUserId','testToken','hashedTokenValue123','2325-10-02 16:01:38.000000000+00:00','testUserId','2024-10-02 16:01:38.000000000+00:00',NULL); +INSERT INTO personal_access_tokens VALUES('serviceTokenId','testServiceUserId','serviceToken','hashedServiceTokenValue123','2325-10-02 16:01:38.000000000+00:00','testAdminId','2024-10-02 16:01:38.000000000+00:00',NULL); \ No newline at end of file diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index f5c2aafa6..0203d6177 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -9,10 +9,14 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/metric/noop" + "github.com/netbirdio/management-integrations/integrations" + accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" - reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" @@ -31,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" serverauth "github.com/netbirdio/netbird/management/server/auth" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -83,23 +88,39 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) - domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager) - reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager) - proxyServiceServer.SetProxyManager(reverseProxyManager) - am.SetServiceManager(reverseProxyManager) + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) + noopMeter := noop.NewMeterProvider().Meter("") + proxyMgr, err := proxymanager.NewManager(store, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy manager: %v", err) + } + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) + serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy controller: %v", err) + } + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) + proxyServiceServer.SetServiceManager(serviceManager) + am.SetServiceManager(serviceManager) // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil) authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, @@ -107,14 +128,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee GetPATInfoFunc: authManager.GetPATInfo, } - networksManagerMock := networks.NewManagerMock() - resourcesManagerMock := resources.NewManagerMock() - routersManagerMock := routers.NewManagerMock() - groupsManagerMock := groups.NewManagerMock() + groupsManager := groups.NewManager(store, permissionsManager, am) + routersManager := routers.NewManager(store, permissionsManager, am) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am) customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -146,6 +167,111 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m } } +// PeerShouldReceiveAnyUpdate waits for a peer update message and returns it. +// Fails the test if no update is received within timeout. +func PeerShouldReceiveAnyUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) *network_map.UpdateMessage { + t.Helper() + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + return msg + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + return nil + } +} + +// PeerShouldNotReceiveAnyUpdate verifies no peer update message is received. +func PeerShouldNotReceiveAnyUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) { + t.Helper() + peerShouldNotReceiveUpdate(t, updateMessage) +} + +// BuildApiBlackBoxWithDBStateAndPeerChannel creates the API handler and returns +// the peer update channel directly so tests can verify updates inline. +func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile string) (http.Handler, account.Manager, <-chan *network_map.UpdateMessage) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := update_channel.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) + + geoMock := &geolocation.Mock{} + validatorMock := server.MockIntegratedValidator{} + proxyController := integrations.NewController(store) + userManager := users.NewManager(store) + permissionsManager := permissions.NewManager(store) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) + peersManager := peers.NewManager(store, permissionsManager) + + jobManager := job.NewJobManager(nil, store, peersManager) + + ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + + requestBuffer := server.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) + noopMeter := noop.NewMeterProvider().Meter("") + proxyMgr, err := proxymanager.NewManager(store, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy manager: %v", err) + } + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) + serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy controller: %v", err) + } + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) + proxyServiceServer.SetServiceManager(serviceManager) + am.SetServiceManager(serviceManager) + + // @note this is required so that PAT's validate from store, but JWT's are mocked + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil) + authManagerMock := &serverauth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, + MarkPATUsedFunc: authManager.MarkPATUsed, + GetPATInfoFunc: authManager.GetPATInfo, + } + + groupsManager := groups.NewManager(store, permissionsManager, am) + routersManager := routers.NewManager(store, permissionsManager, am) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am) + customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") + zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) + + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, updMsg +} + func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) { userAuth := auth.UserAuth{} diff --git a/management/server/http/testing/testing_tools/db_verify.go b/management/server/http/testing/testing_tools/db_verify.go new file mode 100644 index 000000000..f8af6a41f --- /dev/null +++ b/management/server/http/testing/testing_tools/db_verify.go @@ -0,0 +1,222 @@ +package testing_tools + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// GetDB extracts the *gorm.DB from a store.Store (must be *SqlStore). +func GetDB(t *testing.T, s store.Store) *gorm.DB { + t.Helper() + sqlStore, ok := s.(*store.SqlStore) + require.True(t, ok, "Store is not a *SqlStore, cannot get gorm.DB") + return sqlStore.GetDB() +} + +// VerifyGroupInDB reads a group directly from the DB and returns it. +func VerifyGroupInDB(t *testing.T, db *gorm.DB, groupID string) *types.Group { + t.Helper() + var group types.Group + err := db.Where("id = ? AND account_id = ?", groupID, TestAccountId).First(&group).Error + require.NoError(t, err, "Expected group %s to exist in DB", groupID) + return &group +} + +// VerifyGroupNotInDB verifies that a group does not exist in the DB. +func VerifyGroupNotInDB(t *testing.T, db *gorm.DB, groupID string) { + t.Helper() + var count int64 + db.Model(&types.Group{}).Where("id = ? AND account_id = ?", groupID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected group %s to NOT exist in DB", groupID) +} + +// VerifyPolicyInDB reads a policy directly from the DB and returns it. +func VerifyPolicyInDB(t *testing.T, db *gorm.DB, policyID string) *types.Policy { + t.Helper() + var policy types.Policy + err := db.Preload("Rules").Where("id = ? AND account_id = ?", policyID, TestAccountId).First(&policy).Error + require.NoError(t, err, "Expected policy %s to exist in DB", policyID) + return &policy +} + +// VerifyPolicyNotInDB verifies that a policy does not exist in the DB. +func VerifyPolicyNotInDB(t *testing.T, db *gorm.DB, policyID string) { + t.Helper() + var count int64 + db.Model(&types.Policy{}).Where("id = ? AND account_id = ?", policyID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected policy %s to NOT exist in DB", policyID) +} + +// VerifyRouteInDB reads a route directly from the DB and returns it. +func VerifyRouteInDB(t *testing.T, db *gorm.DB, routeID route.ID) *route.Route { + t.Helper() + var r route.Route + err := db.Where("id = ? AND account_id = ?", routeID, TestAccountId).First(&r).Error + require.NoError(t, err, "Expected route %s to exist in DB", routeID) + return &r +} + +// VerifyRouteNotInDB verifies that a route does not exist in the DB. +func VerifyRouteNotInDB(t *testing.T, db *gorm.DB, routeID route.ID) { + t.Helper() + var count int64 + db.Model(&route.Route{}).Where("id = ? AND account_id = ?", routeID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected route %s to NOT exist in DB", routeID) +} + +// VerifyNSGroupInDB reads a nameserver group directly from the DB and returns it. +func VerifyNSGroupInDB(t *testing.T, db *gorm.DB, nsGroupID string) *nbdns.NameServerGroup { + t.Helper() + var nsGroup nbdns.NameServerGroup + err := db.Where("id = ? AND account_id = ?", nsGroupID, TestAccountId).First(&nsGroup).Error + require.NoError(t, err, "Expected NS group %s to exist in DB", nsGroupID) + return &nsGroup +} + +// VerifyNSGroupNotInDB verifies that a nameserver group does not exist in the DB. +func VerifyNSGroupNotInDB(t *testing.T, db *gorm.DB, nsGroupID string) { + t.Helper() + var count int64 + db.Model(&nbdns.NameServerGroup{}).Where("id = ? AND account_id = ?", nsGroupID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected NS group %s to NOT exist in DB", nsGroupID) +} + +// VerifyPeerInDB reads a peer directly from the DB and returns it. +func VerifyPeerInDB(t *testing.T, db *gorm.DB, peerID string) *nbpeer.Peer { + t.Helper() + var peer nbpeer.Peer + err := db.Where("id = ? AND account_id = ?", peerID, TestAccountId).First(&peer).Error + require.NoError(t, err, "Expected peer %s to exist in DB", peerID) + return &peer +} + +// VerifyPeerNotInDB verifies that a peer does not exist in the DB. +func VerifyPeerNotInDB(t *testing.T, db *gorm.DB, peerID string) { + t.Helper() + var count int64 + db.Model(&nbpeer.Peer{}).Where("id = ? AND account_id = ?", peerID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected peer %s to NOT exist in DB", peerID) +} + +// VerifySetupKeyInDB reads a setup key directly from the DB and returns it. +func VerifySetupKeyInDB(t *testing.T, db *gorm.DB, keyID string) *types.SetupKey { + t.Helper() + var key types.SetupKey + err := db.Where("id = ? AND account_id = ?", keyID, TestAccountId).First(&key).Error + require.NoError(t, err, "Expected setup key %s to exist in DB", keyID) + return &key +} + +// VerifySetupKeyNotInDB verifies that a setup key does not exist in the DB. +func VerifySetupKeyNotInDB(t *testing.T, db *gorm.DB, keyID string) { + t.Helper() + var count int64 + db.Model(&types.SetupKey{}).Where("id = ? AND account_id = ?", keyID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected setup key %s to NOT exist in DB", keyID) +} + +// VerifyUserInDB reads a user directly from the DB and returns it. +func VerifyUserInDB(t *testing.T, db *gorm.DB, userID string) *types.User { + t.Helper() + var user types.User + err := db.Where("id = ? AND account_id = ?", userID, TestAccountId).First(&user).Error + require.NoError(t, err, "Expected user %s to exist in DB", userID) + return &user +} + +// VerifyUserNotInDB verifies that a user does not exist in the DB. +func VerifyUserNotInDB(t *testing.T, db *gorm.DB, userID string) { + t.Helper() + var count int64 + db.Model(&types.User{}).Where("id = ? AND account_id = ?", userID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected user %s to NOT exist in DB", userID) +} + +// VerifyPATInDB reads a PAT directly from the DB and returns it. +func VerifyPATInDB(t *testing.T, db *gorm.DB, tokenID string) *types.PersonalAccessToken { + t.Helper() + var pat types.PersonalAccessToken + err := db.Where("id = ?", tokenID).First(&pat).Error + require.NoError(t, err, "Expected PAT %s to exist in DB", tokenID) + return &pat +} + +// VerifyPATNotInDB verifies that a PAT does not exist in the DB. +func VerifyPATNotInDB(t *testing.T, db *gorm.DB, tokenID string) { + t.Helper() + var count int64 + db.Model(&types.PersonalAccessToken{}).Where("id = ?", tokenID).Count(&count) + assert.Equal(t, int64(0), count, "Expected PAT %s to NOT exist in DB", tokenID) +} + +// VerifyAccountSettings reads the account and returns its settings from the DB. +func VerifyAccountSettings(t *testing.T, db *gorm.DB) *types.Account { + t.Helper() + var account types.Account + err := db.Where("id = ?", TestAccountId).First(&account).Error + require.NoError(t, err, "Expected account %s to exist in DB", TestAccountId) + return &account +} + +// VerifyNetworkInDB reads a network directly from the store and returns it. +func VerifyNetworkInDB(t *testing.T, db *gorm.DB, networkID string) *networkTypes.Network { + t.Helper() + var network networkTypes.Network + err := db.Where("id = ? AND account_id = ?", networkID, TestAccountId).First(&network).Error + require.NoError(t, err, "Expected network %s to exist in DB", networkID) + return &network +} + +// VerifyNetworkNotInDB verifies that a network does not exist in the DB. +func VerifyNetworkNotInDB(t *testing.T, db *gorm.DB, networkID string) { + t.Helper() + var count int64 + db.Model(&networkTypes.Network{}).Where("id = ? AND account_id = ?", networkID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected network %s to NOT exist in DB", networkID) +} + +// VerifyNetworkResourceInDB reads a network resource directly from the DB and returns it. +func VerifyNetworkResourceInDB(t *testing.T, db *gorm.DB, resourceID string) *resourceTypes.NetworkResource { + t.Helper() + var resource resourceTypes.NetworkResource + err := db.Where("id = ? AND account_id = ?", resourceID, TestAccountId).First(&resource).Error + require.NoError(t, err, "Expected network resource %s to exist in DB", resourceID) + return &resource +} + +// VerifyNetworkResourceNotInDB verifies that a network resource does not exist in the DB. +func VerifyNetworkResourceNotInDB(t *testing.T, db *gorm.DB, resourceID string) { + t.Helper() + var count int64 + db.Model(&resourceTypes.NetworkResource{}).Where("id = ? AND account_id = ?", resourceID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected network resource %s to NOT exist in DB", resourceID) +} + +// VerifyNetworkRouterInDB reads a network router directly from the DB and returns it. +func VerifyNetworkRouterInDB(t *testing.T, db *gorm.DB, routerID string) *routerTypes.NetworkRouter { + t.Helper() + var router routerTypes.NetworkRouter + err := db.Where("id = ? AND account_id = ?", routerID, TestAccountId).First(&router).Error + require.NoError(t, err, "Expected network router %s to exist in DB", routerID) + return &router +} + +// VerifyNetworkRouterNotInDB verifies that a network router does not exist in the DB. +func VerifyNetworkRouterNotInDB(t *testing.T, db *gorm.DB, routerID string) { + t.Helper() + var count int64 + db.Model(&routerTypes.NetworkRouter{}).Where("id = ? AND account_id = ?", routerID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected network router %s to NOT exist in DB", routerID) +} diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go index 9fce6b9c0..d51254c55 100644 --- a/management/server/identity_provider_test.go +++ b/management/server/identity_provider_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "path/filepath" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -19,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -83,10 +85,15 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update permissionsManager := permissions.NewManager(testStore) peersManager := peers.NewManager(testStore, permissionsManager) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, testStore) networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index 8ab4ce0dc..48d3221cc 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/management/server/telemetry" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -48,11 +49,13 @@ type EmbeddedIdPConfig struct { // Existing local users are preserved and will be able to login again if re-enabled. // Cannot be enabled if no external identity provider connectors are configured. LocalAuthDisabled bool + // StaticConnectors are additional connectors to seed during initialization + StaticConnectors []dex.Connector } // EmbeddedStorageConfig holds storage configuration for the embedded IdP. type EmbeddedStorageConfig struct { - // Type is the storage type (currently only "sqlite3" is supported) + // Type is the storage type: "sqlite3" (default) or "postgres" Type string // Config contains type-specific configuration Config EmbeddedStorageTypeConfig @@ -62,6 +65,8 @@ type EmbeddedStorageConfig struct { type EmbeddedStorageTypeConfig struct { // File is the path to the SQLite database file (for sqlite3 type) File string + // DSN is the connection string for postgres + DSN string } // OwnerConfig represents the initial owner/admin user for the embedded IdP. @@ -74,6 +79,22 @@ type OwnerConfig struct { Username string } +// buildIdpStorageConfig builds the Dex storage config map based on the storage type. +func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) { + switch storageType { + case "sqlite3": + return map[string]interface{}{ + "file": cfg.File, + }, nil + case "postgres": + return map[string]interface{}{ + "dsn": cfg.DSN, + }, nil + default: + return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType) + } +} + // ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig. func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { if c.Issuer == "" { @@ -85,6 +106,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" { return nil, fmt.Errorf("storage file is required for sqlite3") } + if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" { + return nil, fmt.Errorf("storage DSN is required for postgres") + } + + storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config) + if err != nil { + return nil, fmt.Errorf("invalid IdP storage config: %w", err) + } // Build CLI redirect URIs including the device callback (both relative and absolute) cliRedirectURIs := c.CLIRedirectURIs @@ -100,10 +129,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { cfg := &dex.YAMLConfig{ Issuer: c.Issuer, Storage: dex.Storage{ - Type: c.Storage.Type, - Config: map[string]interface{}{ - "file": c.Storage.Config.File, - }, + Type: c.Storage.Type, + Config: storageConfig, }, Web: dex.Web{ AllowedOrigins: []string{"*"}, @@ -133,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { RedirectURIs: cliRedirectURIs, }, }, + StaticConnectors: c.StaticConnectors, } // Add owner user if provided @@ -169,6 +197,9 @@ type OAuthConfigProvider interface { // Management server has embedded Dex and can validate tokens via localhost, // avoiding external network calls and DNS resolution issues during startup. GetLocalKeysLocation() string + // GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage, + // or nil if direct key fetching is not supported (falls back to HTTP). + GetKeyFetcher() nbjwt.KeyFetcher GetClientIDs() []string GetUserIDClaim() string GetTokenEndpoint() string @@ -569,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string { return m.config.CLIRedirectURIs } +// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage. +func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher { + return m.provider.GetJWKS +} + // GetKeysLocation returns the JWKS endpoint URL for token validation. func (m *EmbeddedIdPManager) GetKeysLocation() string { return m.provider.GetKeysLocation() diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 48e4f3000..dadbfd83e 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -66,14 +66,14 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient } // Create a new Admin SDK Directory service client - adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) + credentialsOption, err := getGoogleCredentialsOption(ctx, config.ServiceAccountKey) if err != nil { return nil, err } service, err := admin.NewService(context.Background(), option.WithScopes(admin.AdminDirectoryUserReadonlyScope), - option.WithCredentials(adminCredentials), + credentialsOption, ) if err != nil { return nil, err @@ -218,39 +218,32 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e return nil } -// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. -// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. -// If that fails, it falls back to using the default Google credentials path. -// It returns the retrieved credentials or an error if unsuccessful. -func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { +// getGoogleCredentialsOption returns the google.golang.org/api option carrying +// Google credentials derived from the provided serviceAccountKey. +// It decodes the base64-encoded serviceAccountKey and uses it as the credentials JSON. +// If the key is empty, it falls back to the default Google credentials path. +func getGoogleCredentialsOption(ctx context.Context, serviceAccountKey string) (option.ClientOption, error) { log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) if err != nil { return nil, fmt.Errorf("failed to decode service account key: %w", err) } - creds, err := google.CredentialsFromJSON( - context.Background(), - decodeKey, - admin.AdminDirectoryUserReadonlyScope, - ) - if err == nil { - // No need to fallback to the default Google credentials path - return creds, nil + if len(decodeKey) > 0 { + return option.WithAuthCredentialsJSON(option.ServiceAccount, decodeKey), nil } - log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) - log.WithContext(ctx).Debug("falling back to default google credentials location") + log.WithContext(ctx).Debug("no service account key provided, falling back to default google credentials location") - creds, err = google.FindDefaultCredentials( - context.Background(), + creds, err := google.FindDefaultCredentials( + ctx, admin.AdminDirectoryUserReadonlyScope, ) if err != nil { return nil, err } - return creds, nil + return option.WithCredentials(creds), nil } // parseGoogleWorkspaceUser parse google user to UserData. diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 28e3d81f9..20d6cacd5 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -197,6 +197,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr case "jumpcloud": return NewJumpCloudManager(JumpCloudClientConfig{ APIToken: config.ExtraConfig["ApiToken"], + ApiUrl: config.ExtraConfig["ApiUrl"], }, appMetrics) case "pocketid": return NewPocketIdManager(PocketIdClientConfig{ diff --git a/management/server/idp/jumpcloud.go b/management/server/idp/jumpcloud.go index 8c4a9d089..f0dec3a9b 100644 --- a/management/server/idp/jumpcloud.go +++ b/management/server/idp/jumpcloud.go @@ -1,24 +1,40 @@ package idp import ( + "bytes" "context" + "encoding/json" "fmt" + "io" "net/http" "strings" - v1 "github.com/TheJumpCloud/jcapi-go/v1" - "github.com/netbirdio/netbird/management/server/telemetry" ) const ( - contentType = "application/json" - accept = "application/json" + jumpCloudDefaultApiUrl = "https://console.jumpcloud.com" + jumpCloudSearchPageSize = 100 ) +// jumpCloudUser represents a JumpCloud V1 API system user. +type jumpCloudUser struct { + ID string `json:"_id"` + Email string `json:"email"` + Firstname string `json:"firstname"` + Middlename string `json:"middlename"` + Lastname string `json:"lastname"` +} + +// jumpCloudUserList represents the response from the JumpCloud search endpoint. +type jumpCloudUserList struct { + Results []jumpCloudUser `json:"results"` + TotalCount int `json:"totalCount"` +} + // JumpCloudManager JumpCloud manager client instance. type JumpCloudManager struct { - client *v1.APIClient + apiBase string apiToken string httpClient ManagerHTTPClient credentials ManagerCredentials @@ -29,6 +45,7 @@ type JumpCloudManager struct { // JumpCloudClientConfig JumpCloud manager client configurations. type JumpCloudClientConfig struct { APIToken string + ApiUrl string } // JumpCloudCredentials JumpCloud authentication information. @@ -55,7 +72,15 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing") } - client := v1.NewAPIClient(v1.NewConfiguration()) + apiBase := config.ApiUrl + if apiBase == "" { + apiBase = jumpCloudDefaultApiUrl + } + apiBase = strings.TrimSuffix(apiBase, "/") + if !strings.HasSuffix(apiBase, "/api") { + apiBase += "/api" + } + credentials := &JumpCloudCredentials{ clientConfig: config, httpClient: httpClient, @@ -64,7 +89,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM } return &JumpCloudManager{ - client: client, + apiBase: apiBase, apiToken: config.APIToken, httpClient: httpClient, credentials: credentials, @@ -78,37 +103,58 @@ func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error return JWTToken{}, nil } -func (jm *JumpCloudManager) authenticationContext() context.Context { - return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{ - Key: jm.apiToken, - }) -} - -// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { - return nil -} - -// GetUserDataByID requests user data from JumpCloud via ID. -func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { - authCtx := jm.authenticationContext() - user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil) +// doRequest executes an HTTP request against the JumpCloud V1 API. +func (jm *JumpCloudManager) doRequest(ctx context.Context, method, path string, body io.Reader) ([]byte, error) { + reqURL := jm.apiBase + path + req, err := http.NewRequestWithContext(ctx, method, reqURL, body) if err != nil { return nil, err } + + req.Header.Set("x-api-key", jm.apiToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := jm.httpClient.Do(req) + if err != nil { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) + return nil, fmt.Errorf("JumpCloud API request %s %s failed with status %d", method, path, resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +// GetUserDataByID requests user data from JumpCloud via ID. +func (jm *JumpCloudManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + body, err := jm.doRequest(ctx, http.MethodGet, "/systemusers/"+userID, nil) + if err != nil { + return nil, err } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetUserDataByID() } + var user jumpCloudUser + if err = jm.helper.Unmarshal(body, &user); err != nil { + return nil, err + } + userData := parseJumpCloudUser(user) userData.AppMetadata = appMetadata @@ -116,30 +162,20 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap } // GetAccount returns all the users for a given profile. -func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { - authCtx := jm.authenticationContext() - userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) +func (jm *JumpCloudManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + allUsers, err := jm.searchAllUsers(ctx) if err != nil { return nil, err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode) - } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetAccount() } - users := make([]*UserData, 0) - for _, user := range userList.Results { + users := make([]*UserData, 0, len(allUsers)) + for _, user := range allUsers { userData := parseJumpCloudUser(user) userData.AppMetadata.WTAccountID = accountID - users = append(users, userData) } @@ -148,27 +184,18 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]* // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { - authCtx := jm.authenticationContext() - userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) +func (jm *JumpCloudManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + allUsers, err := jm.searchAllUsers(ctx) if err != nil { return nil, err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) - } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetAllAccounts() } indexedUsers := make(map[string][]*UserData) - for _, user := range userList.Results { + for _, user := range allUsers { userData := parseJumpCloudUser(user) indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } @@ -176,6 +203,41 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use return indexedUsers, nil } +// searchAllUsers paginates through all system users using limit/skip. +func (jm *JumpCloudManager) searchAllUsers(ctx context.Context) ([]jumpCloudUser, error) { + var allUsers []jumpCloudUser + + for skip := 0; ; skip += jumpCloudSearchPageSize { + searchReq := map[string]int{ + "limit": jumpCloudSearchPageSize, + "skip": skip, + } + + payload, err := json.Marshal(searchReq) + if err != nil { + return nil, err + } + + body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload)) + if err != nil { + return nil, err + } + + var userList jumpCloudUserList + if err = jm.helper.Unmarshal(body, &userList); err != nil { + return nil, err + } + + allUsers = append(allUsers, userList.Results...) + + if skip+len(userList.Results) >= userList.TotalCount { + break + } + } + + return allUsers, nil +} + // CreateUser creates a new user in JumpCloud Idp and sends an invitation. func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") @@ -183,7 +245,7 @@ func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*U // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { +func (jm *JumpCloudManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { searchFilter := map[string]interface{}{ "searchFilter": map[string]interface{}{ "filter": []string{email}, @@ -191,25 +253,26 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]* }, } - authCtx := jm.authenticationContext() - userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter) + payload, err := json.Marshal(searchFilter) if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode) + body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload)) + if err != nil { + return nil, err } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetUserByEmail() } - usersData := make([]*UserData, 0) + var userList jumpCloudUserList + if err = jm.helper.Unmarshal(body, &userList); err != nil { + return nil, err + } + + usersData := make([]*UserData, 0, len(userList.Results)) for _, user := range userList.Results { usersData = append(usersData, parseJumpCloudUser(user)) } @@ -224,20 +287,11 @@ func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error { } // DeleteUser from jumpCloud directory -func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error { - authCtx := jm.authenticationContext() - _, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil) +func (jm *JumpCloudManager) DeleteUser(ctx context.Context, userID string) error { + _, err := jm.doRequest(ctx, http.MethodDelete, "/systemusers/"+userID, nil) if err != nil { return err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) - } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountDeleteUser() @@ -247,11 +301,11 @@ func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error { } // parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData. -func parseJumpCloudUser(user v1.Systemuserreturn) *UserData { +func parseJumpCloudUser(user jumpCloudUser) *UserData { names := []string{user.Firstname, user.Middlename, user.Lastname} return &UserData{ Email: user.Email, Name: strings.Join(names, " "), - ID: user.Id, + ID: user.ID, } } diff --git a/management/server/idp/jumpcloud_test.go b/management/server/idp/jumpcloud_test.go index 1bfdcefcc..dc7a9cb6c 100644 --- a/management/server/idp/jumpcloud_test.go +++ b/management/server/idp/jumpcloud_test.go @@ -1,8 +1,15 @@ package idp import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/telemetry" @@ -44,3 +51,212 @@ func TestNewJumpCloudManager(t *testing.T) { }) } } + +func TestJumpCloudGetUserDataByID(t *testing.T) { + userResponse := jumpCloudUser{ + ID: "user123", + Email: "test@example.com", + Firstname: "John", + Middlename: "", + Lastname: "Doe", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/systemusers/user123", r.URL.Path) + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "test-api-key", r.Header.Get("x-api-key")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(userResponse) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + userData, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{WTAccountID: "acc1"}) + require.NoError(t, err) + + assert.Equal(t, "user123", userData.ID) + assert.Equal(t, "test@example.com", userData.Email) + assert.Equal(t, "John Doe", userData.Name) + assert.Equal(t, "acc1", userData.AppMetadata.WTAccountID) +} + +func TestJumpCloudGetAccount(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/search/systemusers", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + + var reqBody map[string]any + assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + assert.Contains(t, reqBody, "limit") + assert.Contains(t, reqBody, "skip") + + resp := jumpCloudUserList{ + Results: []jumpCloudUser{ + {ID: "u1", Email: "a@test.com", Firstname: "Alice", Lastname: "Smith"}, + {ID: "u2", Email: "b@test.com", Firstname: "Bob", Lastname: "Jones"}, + }, + TotalCount: 2, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + users, err := manager.GetAccount(context.Background(), "testAccount") + require.NoError(t, err) + assert.Len(t, users, 2) + assert.Equal(t, "testAccount", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "testAccount", users[1].AppMetadata.WTAccountID) +} + +func TestJumpCloudGetAllAccounts(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := jumpCloudUserList{ + Results: []jumpCloudUser{ + {ID: "u1", Email: "a@test.com", Firstname: "Alice"}, + {ID: "u2", Email: "b@test.com", Firstname: "Bob"}, + }, + TotalCount: 2, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + indexedUsers, err := manager.GetAllAccounts(context.Background()) + require.NoError(t, err) + assert.Len(t, indexedUsers[UnsetAccountID], 2) +} + +func TestJumpCloudGetAllAccountsPagination(t *testing.T) { + totalUsers := 250 + allUsers := make([]jumpCloudUser, totalUsers) + for i := range allUsers { + allUsers[i] = jumpCloudUser{ + ID: fmt.Sprintf("u%d", i), + Email: fmt.Sprintf("user%d@test.com", i), + Firstname: fmt.Sprintf("User%d", i), + } + } + + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]int + assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + + limit := reqBody["limit"] + skip := reqBody["skip"] + requestCount++ + + end := skip + limit + if end > totalUsers { + end = totalUsers + } + + resp := jumpCloudUserList{ + Results: allUsers[skip:end], + TotalCount: totalUsers, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + indexedUsers, err := manager.GetAllAccounts(context.Background()) + require.NoError(t, err) + assert.Len(t, indexedUsers[UnsetAccountID], totalUsers) + assert.Equal(t, 3, requestCount, "should require 3 pages for 250 users at page size 100") +} + +func TestJumpCloudGetUserByEmail(t *testing.T) { + searchResponse := jumpCloudUserList{ + Results: []jumpCloudUser{ + {ID: "u1", Email: "alice@test.com", Firstname: "Alice", Lastname: "Smith"}, + }, + TotalCount: 1, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/search/systemusers", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Contains(t, string(body), "alice@test.com") + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(searchResponse) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + users, err := manager.GetUserByEmail(context.Background(), "alice@test.com") + require.NoError(t, err) + assert.Len(t, users, 1) + assert.Equal(t, "alice@test.com", users[0].Email) +} + +func TestJumpCloudDeleteUser(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/systemusers/user123", r.URL.Path) + assert.Equal(t, http.MethodDelete, r.Method) + assert.Equal(t, "test-api-key", r.Header.Get("x-api-key")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"_id": "user123"}) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + err := manager.DeleteUser(context.Background(), "user123") + require.NoError(t, err) +} + +func TestJumpCloudAPIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + _, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +func TestParseJumpCloudUser(t *testing.T) { + user := jumpCloudUser{ + ID: "abc123", + Email: "test@example.com", + Firstname: "John", + Middlename: "M", + Lastname: "Doe", + } + + userData := parseJumpCloudUser(user) + assert.Equal(t, "abc123", userData.ID) + assert.Equal(t, "test@example.com", userData.Email) + assert.Equal(t, "John M Doe", userData.Name) +} + +func newTestJumpCloudManager(t *testing.T, apiBase string) *JumpCloudManager { + t.Helper() + return &JumpCloudManager{ + apiBase: apiBase, + apiToken: "test-api-key", + httpClient: http.DefaultClient, + helper: JsonParser{}, + appMetrics: nil, + } +} diff --git a/management/server/idp/migration/migration.go b/management/server/idp/migration/migration.go new file mode 100644 index 000000000..01cadb86d --- /dev/null +++ b/management/server/idp/migration/migration.go @@ -0,0 +1,235 @@ +// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0 +// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later. +// It includes functions to seed connectors and migrate existing users to use these connectors. +package migration + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/types" +) + +// Server is the dependency interface that migration functions use to access +// the main data store and the activity event store. +type Server interface { + Store() Store + EventStore() EventStore // may return nil +} + +const idpSeedInfoKey = "IDP_SEED_INFO" +const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN" + +func isDryRun() bool { + return os.Getenv(dryRunEnvKey) == "true" +} + +var ErrNoSeedInfo = errors.New("no seed info found in environment") + +// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it, +// and JSON-unmarshals it into a dex.Connector. Returns nil if not set. +func SeedConnectorFromEnv() (*dex.Connector, error) { + val, ok := os.LookupEnv(idpSeedInfoKey) + if !ok || val == "" { + return nil, ErrNoSeedInfo + } + + decoded, err := base64.StdEncoding.DecodeString(val) + if err != nil { + return nil, fmt.Errorf("base64 decode: %w", err) + } + + var conn dex.Connector + if err := json.Unmarshal(decoded, &conn); err != nil { + return nil, fmt.Errorf("json unmarshal: %w", err) + } + + return &conn, nil +} + +// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and +// the activity store, if present) so that it encodes the given connector ID, +// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true +// to log what would happen without writing any changes. +func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error { + ctx := context.Background() + + if isDryRun() { + log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written") + } + + users, err := s.Store().ListUsers(ctx) + if err != nil { + return fmt.Errorf("failed to list users: %w", err) + } + + // Reconciliation pass: fix activity store for users already migrated in main DB + // but whose activity references may still use old IDs (from a previous partial failure). + if s.EventStore() != nil && !isDryRun() { + if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil { + return err + } + } + + var migratedCount, skippedCount int + + for _, user := range users { + _, _, decErr := dex.DecodeDexUserID(user.Id) + if decErr == nil { + skippedCount++ + continue + } + + newUserID := dex.EncodeDexUserID(user.Id, conn.ID) + + if isDryRun() { + log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID) + migratedCount++ + continue + } + + if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil { + return err + } + + migratedCount++ + } + + if isDryRun() { + log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount) + } else { + log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount) + } + + return nil +} + +// reconcileActivityStore updates activity store references for users already migrated +// in the main DB whose activity entries may still use old IDs from a previous partial failure. +func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error { + for _, user := range users { + originalID, _, err := dex.DecodeDexUserID(user.Id) + if err != nil { + // skip users that aren't migrated, they will be handled in the main migration loop + continue + } + if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil { + return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err) + } + } + return nil +} + +// migrateUser updates a single user's ID in both the main store and the activity store. +func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error { + if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil { + return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err) + } + + if s.EventStore() == nil { + return nil + } + + if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil { + return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err) + } + + return nil +} + +// PopulateUserInfo fetches user email and name from the external IDP and updates +// the store for users that are missing this information. +func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error { + ctx := context.Background() + + users, err := s.Store().ListUsers(ctx) + if err != nil { + return fmt.Errorf("failed to list users: %w", err) + } + + // Build a map of IDP user ID -> UserData from the external IDP + allAccounts, err := idpManager.GetAllAccounts(ctx) + if err != nil { + return fmt.Errorf("failed to fetch accounts from IDP: %w", err) + } + + idpUsers := make(map[string]*idp.UserData) + for _, accountUsers := range allAccounts { + for _, userData := range accountUsers { + idpUsers[userData.ID] = userData + } + } + + log.Infof("fetched %d users from IDP", len(idpUsers)) + + var updatedCount, skippedCount, notFoundCount int + + for _, user := range users { + if user.IsServiceUser { + skippedCount++ + continue + } + + if user.Email != "" && user.Name != "" { + skippedCount++ + continue + } + + // The user ID in the store may be the original IDP ID or a Dex-encoded ID. + // Try to decode the Dex format first to get the original IDP ID. + lookupID := user.Id + if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil { + lookupID = originalID + } + + idpUser, found := idpUsers[lookupID] + if !found { + notFoundCount++ + log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID) + continue + } + + email := user.Email + name := user.Name + if email == "" && idpUser.Email != "" { + email = idpUser.Email + } + if name == "" && idpUser.Name != "" { + name = idpUser.Name + } + + if email == user.Email && name == user.Name { + skippedCount++ + continue + } + + if dryRun { + log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name) + updatedCount++ + continue + } + + if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil { + return fmt.Errorf("failed to update user info for %s: %w", user.Id, err) + } + + log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name) + updatedCount++ + } + + if dryRun { + log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount) + } else { + log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount) + } + + return nil +} diff --git a/management/server/idp/migration/migration_test.go b/management/server/idp/migration/migration_test.go new file mode 100644 index 000000000..2ff71347e --- /dev/null +++ b/management/server/idp/migration/migration_test.go @@ -0,0 +1,828 @@ +package migration + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/types" +) + +// testStore is a hand-written mock for MigrationStore. +type testStore struct { + listUsersFunc func(ctx context.Context) ([]*types.User, error) + updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error + updateUserInfoFunc func(ctx context.Context, userID, email, name string) error + checkSchemaFunc func(checks []SchemaCheck) []SchemaError + updateCalls []updateUserIDCall + updateInfoCalls []updateUserInfoCall +} + +type updateUserIDCall struct { + AccountID string + OldUserID string + NewUserID string +} + +type updateUserInfoCall struct { + UserID string + Email string + Name string +} + +func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) { + return s.listUsersFunc(ctx) +} + +func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error { + s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID}) + return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID) +} + +func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error { + s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name}) + if s.updateUserInfoFunc != nil { + return s.updateUserInfoFunc(ctx, userID, email, name) + } + return nil +} + +func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError { + if s.checkSchemaFunc != nil { + return s.checkSchemaFunc(checks) + } + return nil +} + +type testServer struct { + store Store + eventStore EventStore +} + +func (s *testServer) Store() Store { return s.store } +func (s *testServer) EventStore() EventStore { return s.eventStore } + +func TestSeedConnectorFromEnv(t *testing.T) { + t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) { + os.Unsetenv(idpSeedInfoKey) + + conn, err := SeedConnectorFromEnv() + assert.ErrorIs(t, err, ErrNoSeedInfo) + assert.Nil(t, conn) + }) + + t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) { + t.Setenv(idpSeedInfoKey, "") + + conn, err := SeedConnectorFromEnv() + assert.ErrorIs(t, err, ErrNoSeedInfo) + assert.Nil(t, conn) + }) + + t.Run("returns error on invalid base64", func(t *testing.T) { + t.Setenv(idpSeedInfoKey, "not-valid-base64!!!") + + conn, err := SeedConnectorFromEnv() + assert.NotErrorIs(t, err, ErrNoSeedInfo) + assert.Error(t, err) + assert.Nil(t, conn) + assert.Contains(t, err.Error(), "base64 decode") + }) + + t.Run("returns error on invalid JSON", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("not json")) + t.Setenv(idpSeedInfoKey, encoded) + + conn, err := SeedConnectorFromEnv() + assert.NotErrorIs(t, err, ErrNoSeedInfo) + assert.Error(t, err) + assert.Nil(t, conn) + assert.Contains(t, err.Error(), "json unmarshal") + }) + + t.Run("successfully decodes valid connector", func(t *testing.T) { + expected := dex.Connector{ + Type: "oidc", + Name: "Test Provider", + ID: "test-provider", + Config: map[string]any{ + "issuer": "https://example.com", + "clientID": "my-client-id", + "clientSecret": "my-secret", + }, + } + + data, err := json.Marshal(expected) + require.NoError(t, err) + + encoded := base64.StdEncoding.EncodeToString(data) + t.Setenv(idpSeedInfoKey, encoded) + + conn, err := SeedConnectorFromEnv() + assert.NoError(t, err) + require.NotNil(t, conn) + assert.Equal(t, expected.Type, conn.Type) + assert.Equal(t, expected.Name, conn.Name) + assert.Equal(t, expected.ID, conn.ID) + assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"]) + }) +} + +func TestMigrateUsersToStaticConnectors(t *testing.T) { + connector := &dex.Connector{ + Type: "oidc", + Name: "Test Provider", + ID: "test-connector", + } + + t.Run("succeeds with no users", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + }) + + t.Run("returns error when ListUsers fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return nil, fmt.Errorf("db error") + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to list users") + }) + + t.Run("migrates single user with correct encoded ID", func(t *testing.T) { + user := &types.User{Id: "user-1", AccountID: "account-1"} + expectedNewID := dex.EncodeDexUserID("user-1", "test-connector") + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{user}, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + require.Len(t, ms.updateCalls, 1) + assert.Equal(t, "account-1", ms.updateCalls[0].AccountID) + assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID) + assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID) + }) + + t.Run("migrates multiple users", func(t *testing.T) { + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + {Id: "user-3", AccountID: "account-2"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 3) + }) + + t.Run("returns error when UpdateUserID fails", func(t *testing.T) { + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + callCount := 0 + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + callCount++ + if callCount == 2 { + return fmt.Errorf("update failed") + } + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to update user ID for user user-2") + }) + + t.Run("stops on first UpdateUserID error", func(t *testing.T) { + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return fmt.Errorf("update failed") + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.Error(t, err) + assert.Len(t, ms.updateCalls, 1) // stopped after first error + }) + + t.Run("skips already migrated users", func(t *testing.T) { + alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector") + users := []*types.User{ + {Id: alreadyMigratedID, AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 0) + }) + + t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) { + alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector") + users := []*types.User{ + {Id: alreadyMigratedID, AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + {Id: "user-3", AccountID: "account-2"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + // Only user-2 and user-3 should be migrated + assert.Len(t, ms.updateCalls, 2) + assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID) + assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID) + }) + + t.Run("dry run does not call UpdateUserID", func(t *testing.T) { + t.Setenv(dryRunEnvKey, "true") + + users := []*types.User{ + {Id: "user-1", AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + t.Fatal("UpdateUserID should not be called in dry-run mode") + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 0) + }) + + t.Run("dry run skips already migrated users", func(t *testing.T) { + t.Setenv(dryRunEnvKey, "true") + + alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector") + users := []*types.User{ + {Id: alreadyMigratedID, AccountID: "account-1"}, + {Id: "user-2", AccountID: "account-1"}, + } + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return users, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + t.Fatal("UpdateUserID should not be called in dry-run mode") + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + }) + + t.Run("dry run disabled by default", func(t *testing.T) { + user := &types.User{Id: "user-1", AccountID: "account-1"} + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{user}, nil + }, + updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { + return nil + }, + } + + srv := &testServer{store: ms} + err := MigrateUsersToStaticConnectors(srv, connector) + assert.NoError(t, err) + assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run + }) +} + +func TestPopulateUserInfo(t *testing.T) { + noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil } + + t.Run("succeeds with no users", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{}, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("returns error when ListUsers fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return nil, fmt.Errorf("db error") + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{} + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to list users") + }) + + t.Run("returns error when GetAllAccounts fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return nil, fmt.Errorf("idp error") + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to fetch accounts from IDP") + }) + + t.Run("updates user with missing email and name", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "user1@example.com", Name: "User One"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID) + assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "User One", ms.updateInfoCalls[0].Name) + }) + + t.Run("updates only missing email when name exists", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name) + }) + + t.Run("updates only missing name when email exists", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name) + }) + + t.Run("skips users that already have both email and name", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("skips service users", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("skips users not found in IDP", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) { + dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector") + + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 1) + assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID) + assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "User", ms.updateInfoCalls[0].Name) + }) + + t.Run("handles multiple users across multiple accounts", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + {Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"}, + {Id: "user-3", AccountID: "acc-2", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "u1@example.com", Name: "User 1"}, + {ID: "user-2", Email: "u2@example.com", Name: "User 2"}, + }, + "acc-2": { + {ID: "user-3", Email: "u3@example.com", Name: "User 3"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + require.Len(t, ms.updateInfoCalls, 2) + assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID) + assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email) + assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID) + assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email) + }) + + t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error { + return fmt.Errorf("db write error") + }, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to update user info for user-1") + }) + + t.Run("stops on first UpdateUserInfo error", func(t *testing.T) { + callCount := 0 + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + {Id: "user-2", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error { + callCount++ + return fmt.Errorf("db write error") + }, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "u1@example.com", Name: "U1"}, + {ID: "user-2", Email: "u2@example.com", Name: "U2"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.Error(t, err) + assert.Equal(t, 1, callCount) + }) + + t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + {Id: "user-2", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error { + t.Fatal("UpdateUserInfo should not be called in dry-run mode") + return nil + }, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": { + {ID: "user-1", Email: "u1@example.com", Name: "U1"}, + {ID: "user-2", Email: "u2@example.com", Name: "U2"}, + }, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, true) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) + + t.Run("skips user when IDP has empty email and name too", func(t *testing.T) { + ms := &testStore{ + listUsersFunc: func(ctx context.Context) ([]*types.User, error) { + return []*types.User{ + {Id: "user-1", AccountID: "acc-1", Email: "", Name: ""}, + }, nil + }, + updateUserIDFunc: noopUpdateID, + } + mockIDP := &idp.MockIDP{ + GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) { + return map[string][]*idp.UserData{ + "acc-1": {{ID: "user-1", Email: "", Name: ""}}, + }, nil + }, + } + + srv := &testServer{store: ms} + err := PopulateUserInfo(srv, mockIDP, false) + assert.NoError(t, err) + assert.Empty(t, ms.updateInfoCalls) + }) +} + +func TestSchemaError_String(t *testing.T) { + t.Run("missing table", func(t *testing.T) { + e := SchemaError{Table: "jobs"} + assert.Equal(t, `table "jobs" is missing`, e.String()) + }) + + t.Run("missing column", func(t *testing.T) { + e := SchemaError{Table: "users", Column: "email"} + assert.Equal(t, `column "email" on table "users" is missing`, e.String()) + }) +} + +func TestRequiredSchema(t *testing.T) { + // Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo. + expectedTables := []string{ + "users", + "personal_access_tokens", + "peers", + "accounts", + "user_invites", + "proxy_access_tokens", + "jobs", + } + + schemaTableNames := make([]string, len(RequiredSchema)) + for i, s := range RequiredSchema { + schemaTableNames[i] = s.Table + } + + for _, expected := range expectedTables { + assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected) + } +} + +func TestCheckSchema_MockStore(t *testing.T) { + t.Run("returns nil when all schema exists", func(t *testing.T) { + ms := &testStore{ + checkSchemaFunc: func(checks []SchemaCheck) []SchemaError { + return nil + }, + } + errs := ms.CheckSchema(RequiredSchema) + assert.Empty(t, errs) + }) + + t.Run("returns errors for missing tables", func(t *testing.T) { + ms := &testStore{ + checkSchemaFunc: func(checks []SchemaCheck) []SchemaError { + return []SchemaError{ + {Table: "jobs"}, + {Table: "proxy_access_tokens"}, + } + }, + } + errs := ms.CheckSchema(RequiredSchema) + require.Len(t, errs, 2) + assert.Equal(t, "jobs", errs[0].Table) + assert.Equal(t, "", errs[0].Column) + assert.Equal(t, "proxy_access_tokens", errs[1].Table) + }) + + t.Run("returns errors for missing columns", func(t *testing.T) { + ms := &testStore{ + checkSchemaFunc: func(checks []SchemaCheck) []SchemaError { + return []SchemaError{ + {Table: "users", Column: "email"}, + {Table: "users", Column: "name"}, + } + }, + } + errs := ms.CheckSchema(RequiredSchema) + require.Len(t, errs, 2) + assert.Equal(t, "users", errs[0].Table) + assert.Equal(t, "email", errs[0].Column) + }) +} diff --git a/management/server/idp/migration/store.go b/management/server/idp/migration/store.go new file mode 100644 index 000000000..e7cc54a41 --- /dev/null +++ b/management/server/idp/migration/store.go @@ -0,0 +1,82 @@ +package migration + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/management/server/types" +) + +// SchemaCheck represents a table and the columns required on it. +type SchemaCheck struct { + Table string + Columns []string +} + +// RequiredSchema lists all tables and columns that the migration tool needs. +// If any are missing, the user must upgrade their management server first so +// that the automatic GORM migrations create them. +var RequiredSchema = []SchemaCheck{ + {Table: "users", Columns: []string{"id", "email", "name", "account_id"}}, + {Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}}, + {Table: "peers", Columns: []string{"user_id"}}, + {Table: "accounts", Columns: []string{"created_by"}}, + {Table: "user_invites", Columns: []string{"created_by"}}, + {Table: "proxy_access_tokens", Columns: []string{"created_by"}}, + {Table: "jobs", Columns: []string{"triggered_by"}}, +} + +// SchemaError describes a single missing table or column. +type SchemaError struct { + Table string + Column string // empty when the whole table is missing +} + +func (e SchemaError) String() string { + if e.Column == "" { + return fmt.Sprintf("table %q is missing", e.Table) + } + return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table) +} + +// Store defines the data store operations required for IdP user migration. +// This interface is separate from the main store.Store interface because these methods +// are only used during one-time migration and should be removed once migration tooling +// is no longer needed. +// +// The SQL store implementations (SqlStore) already have these methods on their concrete +// types, so they satisfy this interface via Go's structural typing with zero code changes. +type Store interface { + // ListUsers returns all users across all accounts. + ListUsers(ctx context.Context) ([]*types.User, error) + + // UpdateUserID atomically updates a user's ID and all foreign key references + // across the database (peers, groups, policies, PATs, etc.). + UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error + + // UpdateUserInfo updates a user's email and name in the store. + UpdateUserInfo(ctx context.Context, userID, email, name string) error + + // CheckSchema verifies that all tables and columns required by the migration + // exist in the database. Returns a list of problems; an empty slice means OK. + CheckSchema(checks []SchemaCheck) []SchemaError +} + +// RequiredEventSchema lists all tables and columns that the migration tool needs +// in the activity/event store. +var RequiredEventSchema = []SchemaCheck{ + {Table: "events", Columns: []string{"initiator_id", "target_id"}}, + {Table: "deleted_users", Columns: []string{"id"}}, +} + +// EventStore defines the activity event store operations required for migration. +// Like Store, this is a temporary interface for migration tooling only. +type EventStore interface { + // CheckSchema verifies that all tables and columns required by the migration + // exist in the event database. Returns a list of problems; an empty slice means OK. + CheckSchema(checks []SchemaCheck) []SchemaError + + // UpdateUserID updates all event references (initiator_id, target_id) and + // deleted_users records to use the new user ID format. + UpdateUserID(ctx context.Context, oldUserID, newUserID string) error +} diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 19e3abdc0..9579d7a35 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -64,10 +64,19 @@ type Manager interface { GetVersionInfo(ctx context.Context) (*VersionInfo, error) } +type instanceStore interface { + GetAccountsCounter(ctx context.Context) (int64, error) +} + +type embeddedIdP interface { + CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) + GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error) +} + // DefaultManager is the default implementation of Manager. type DefaultManager struct { - store store.Store - embeddedIdpManager *idp.EmbeddedIdPManager + store instanceStore + embeddedIdpManager embeddedIdP setupRequired bool setupMu sync.RWMutex @@ -82,18 +91,18 @@ type DefaultManager struct { // NewManager creates a new instance manager. // If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults. func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) { - embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager) + embeddedIdp, ok := idpManager.(*idp.EmbeddedIdPManager) m := &DefaultManager{ - store: store, - embeddedIdpManager: embeddedIdp, - setupRequired: false, + store: store, + setupRequired: false, httpClient: &http.Client{ Timeout: httpTimeout, }, } - if embeddedIdp != nil { + if ok && embeddedIdp != nil { + m.embeddedIdpManager = embeddedIdp err := m.loadSetupRequired(ctx) if err != nil { return nil, err @@ -143,36 +152,61 @@ func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) { // CreateOwnerUser creates the initial owner user in the embedded IDP. func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { - if err := m.validateSetupInfo(email, password, name); err != nil { - return nil, err - } - if m.embeddedIdpManager == nil { return nil, errors.New("embedded IDP is not enabled") } - m.setupMu.RLock() - setupRequired := m.setupRequired - m.setupMu.RUnlock() + if err := m.validateSetupInfo(email, password, name); err != nil { + return nil, err + } - if !setupRequired { + m.setupMu.Lock() + defer m.setupMu.Unlock() + + if !m.setupRequired { return nil, status.Errorf(status.PreconditionFailed, "setup already completed") } + if err := m.checkSetupRequiredFromDB(ctx); err != nil { + var sErr *status.Error + if errors.As(err, &sErr) && sErr.Type() == status.PreconditionFailed { + m.setupRequired = false + } + return nil, err + } + userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name) if err != nil { return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err) } - m.setupMu.Lock() m.setupRequired = false - m.setupMu.Unlock() log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email) return userData, nil } +func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error { + numAccounts, err := m.store.GetAccountsCounter(ctx) + if err != nil { + return fmt.Errorf("failed to check accounts: %w", err) + } + if numAccounts > 0 { + return status.Errorf(status.PreconditionFailed, "setup already completed") + } + + users, err := m.embeddedIdpManager.GetAllAccounts(ctx) + if err != nil { + return fmt.Errorf("failed to check IdP users: %w", err) + } + if len(users) > 0 { + return status.Errorf(status.PreconditionFailed, "setup already completed") + } + + return nil +} + func (m *DefaultManager) validateSetupInfo(email, password, name string) error { if email == "" { return status.Errorf(status.InvalidArgument, "email is required") @@ -189,6 +223,9 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error { if len(password) < 8 { return status.Errorf(status.InvalidArgument, "password must be at least 8 characters") } + if len(password) > 72 { + return status.Errorf(status.InvalidArgument, "password must be at most 72 characters") + } return nil } diff --git a/management/server/instance/manager_test.go b/management/server/instance/manager_test.go index 35d0ff53c..e3be9cfea 100644 --- a/management/server/instance/manager_test.go +++ b/management/server/instance/manager_test.go @@ -3,7 +3,12 @@ package instance import ( "context" "errors" + "fmt" + "net/http" + "sync" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -11,173 +16,215 @@ import ( "github.com/netbirdio/netbird/management/server/idp" ) -// mockStore implements a minimal store.Store for testing +type mockIdP struct { + mu sync.Mutex + createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) + users map[string][]*idp.UserData + getAllAccountsErr error +} + +func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createUserFunc != nil { + return m.createUserFunc(ctx, email, password, name) + } + return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil +} + +func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.getAllAccountsErr != nil { + return nil, m.getAllAccountsErr + } + return m.users, nil +} + type mockStore struct { accountsCount int64 err error } -func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) { +func (m *mockStore) GetAccountsCounter(_ context.Context) (int64, error) { if m.err != nil { return 0, m.err } return m.accountsCount, nil } -// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing -type mockEmbeddedIdPManager struct { - createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) -} - -func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) { - if m.createUserFunc != nil { - return m.createUserFunc(ctx, email, password, name) +func newTestManager(idpMock *mockIdP, storeMock *mockStore) *DefaultManager { + return &DefaultManager{ + store: storeMock, + embeddedIdpManager: idpMock, + setupRequired: true, + httpClient: &http.Client{Timeout: httpTimeout}, } - return &idp.UserData{ - ID: "test-user-id", - Email: email, - Name: name, - }, nil -} - -// testManager is a test implementation that accepts our mock types -type testManager struct { - store *mockStore - embeddedIdpManager *mockEmbeddedIdPManager -} - -func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) { - if m.embeddedIdpManager == nil { - return false, nil - } - - count, err := m.store.GetAccountsCounter(ctx) - if err != nil { - return false, err - } - - return count == 0, nil -} - -func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { - if m.embeddedIdpManager == nil { - return nil, errors.New("embedded IDP is not enabled") - } - - return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name) -} - -func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: nil, // No embedded IDP - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.False(t, required, "setup should not be required when embedded IDP is disabled") -} - -func TestIsSetupRequired_NoAccounts(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.True(t, required, "setup should be required when no accounts exist") -} - -func TestIsSetupRequired_AccountsExist(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 1}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.False(t, required, "setup should not be required when accounts exist") -} - -func TestIsSetupRequired_MultipleAccounts(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 5}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - required, err := manager.IsSetupRequired(context.Background()) - require.NoError(t, err) - assert.False(t, required, "setup should not be required when multiple accounts exist") -} - -func TestIsSetupRequired_StoreError(t *testing.T) { - manager := &testManager{ - store: &mockStore{err: errors.New("database error")}, - embeddedIdpManager: &mockEmbeddedIdPManager{}, - } - - _, err := manager.IsSetupRequired(context.Background()) - assert.Error(t, err, "should return error when store fails") } func TestCreateOwnerUser_Success(t *testing.T) { - expectedEmail := "admin@example.com" - expectedName := "Admin User" - expectedPassword := "securepassword123" + idpMock := &mockIdP{} + mgr := newTestManager(idpMock, &mockStore{}) - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: &mockEmbeddedIdPManager{ - createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { - assert.Equal(t, expectedEmail, email) - assert.Equal(t, expectedPassword, password) - assert.Equal(t, expectedName, name) - return &idp.UserData{ - ID: "created-user-id", - Email: email, - Name: name, - }, nil - }, - }, - } - - userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName) + userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") require.NoError(t, err) - assert.Equal(t, "created-user-id", userData.ID) - assert.Equal(t, expectedEmail, userData.Email) - assert.Equal(t, expectedName, userData.Name) + assert.Equal(t, "admin@example.com", userData.Email) + + _, err = mgr.CreateOwnerUser(context.Background(), "admin2@example.com", "password123", "Admin2") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} + +func TestCreateOwnerUser_SetupAlreadyCompleted(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{}) + mgr.setupRequired = false + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") } func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: nil, - } + mgr := &DefaultManager{setupRequired: true} - _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") - assert.Error(t, err, "should return error when embedded IDP is disabled") + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) assert.Contains(t, err.Error(), "embedded IDP is not enabled") } func TestCreateOwnerUser_IdPError(t *testing.T) { - manager := &testManager{ - store: &mockStore{accountsCount: 0}, - embeddedIdpManager: &mockEmbeddedIdPManager{ - createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { - return nil, errors.New("user already exists") - }, + idpMock := &mockIdP{ + createUserFunc: func(_ context.Context, _, _, _ string) (*idp.UserData, error) { + return nil, errors.New("provider error") }, } + mgr := newTestManager(idpMock, &mockStore{}) - _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") - assert.Error(t, err, "should return error when IDP fails") + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "provider error") + + required, _ := mgr.IsSetupRequired(context.Background()) + assert.True(t, required, "setup should still be required after IdP error") +} + +func TestCreateOwnerUser_TransientDBError_DoesNotBlockSetup(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{err: errors.New("connection refused")}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "connection refused") + + required, _ := mgr.IsSetupRequired(context.Background()) + assert.True(t, required, "setup should still be required after transient DB error") + + mgr.store = &mockStore{} + userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.NoError(t, err) + assert.Equal(t, "admin@example.com", userData.Email) +} + +func TestCreateOwnerUser_TransientIdPError_DoesNotBlockSetup(t *testing.T) { + idpMock := &mockIdP{getAllAccountsErr: errors.New("connection reset")} + mgr := newTestManager(idpMock, &mockStore{}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "connection reset") + + required, _ := mgr.IsSetupRequired(context.Background()) + assert.True(t, required, "setup should still be required after transient IdP error") + + idpMock.getAllAccountsErr = nil + userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.NoError(t, err) + assert.Equal(t, "admin@example.com", userData.Email) +} + +func TestCreateOwnerUser_DBCheckBlocksConcurrent(t *testing.T) { + idpMock := &mockIdP{ + users: map[string][]*idp.UserData{ + "acc1": {{ID: "existing-user"}}, + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} + +func TestCreateOwnerUser_DBCheckBlocksWhenAccountsExist(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{accountsCount: 1}) + + _, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} + +func TestCreateOwnerUser_ConcurrentRequests(t *testing.T) { + var idpCallCount atomic.Int32 + var successCount atomic.Int32 + var failCount atomic.Int32 + + idpMock := &mockIdP{ + createUserFunc: func(_ context.Context, email, _, _ string) (*idp.UserData, error) { + idpCallCount.Add(1) + time.Sleep(50 * time.Millisecond) + return &idp.UserData{ID: "user-1", Email: email, Name: "Owner"}, nil + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + + var wg sync.WaitGroup + for i := range 10 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, err := mgr.CreateOwnerUser( + context.Background(), + fmt.Sprintf("owner%d@example.com", idx), + "password1234", + fmt.Sprintf("Owner%d", idx), + ) + if err != nil { + failCount.Add(1) + } else { + successCount.Add(1) + } + }(i) + } + wg.Wait() + + assert.Equal(t, int32(1), successCount.Load(), "exactly one concurrent setup request should succeed") + assert.Equal(t, int32(9), failCount.Load(), "remaining concurrent requests should fail") + assert.Equal(t, int32(1), idpCallCount.Load(), "IdP CreateUser should be called exactly once") +} + +func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) { + mgr := &DefaultManager{} + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required) +} + +func TestIsSetupRequired_ReturnsFlag(t *testing.T) { + mgr := newTestManager(&mockIdP{}, &mockStore{}) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required) + + mgr.setupMu.Lock() + mgr.setupRequired = false + mgr.setupMu.Unlock() + + required, err = mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required) } func TestDefaultManager_ValidateSetupRequest(t *testing.T) { - manager := &DefaultManager{ - setupRequired: true, - } + manager := &DefaultManager{setupRequired: true} tests := []struct { name string @@ -188,11 +235,10 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) { errorMsg string }{ { - name: "valid request", - email: "admin@example.com", - password: "password123", - userName: "Admin User", - expectError: false, + name: "valid request", + email: "admin@example.com", + password: "password123", + userName: "Admin User", }, { name: "empty email", @@ -235,11 +281,24 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) { errorMsg: "password must be at least 8 characters", }, { - name: "password exactly 8 characters", + name: "password exactly 8 characters", + email: "admin@example.com", + password: "12345678", + userName: "Admin User", + }, + { + name: "password exactly 72 characters", + email: "admin@example.com", + password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiii", + userName: "Admin User", + }, + { + name: "password too long", email: "admin@example.com", - password: "12345678", + password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiiij", userName: "Admin User", - expectError: false, + expectError: true, + errorMsg: "password must be at most 72 characters", }, } @@ -255,14 +314,3 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) { }) } } - -func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) { - manager := &DefaultManager{ - setupRequired: false, - embeddedIdpManager: &idp.EmbeddedIdPManager{}, - } - - _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") - require.Error(t, err) - assert.Contains(t, err.Error(), "setup already completed") -} diff --git a/management/server/job/channel.go b/management/server/job/channel.go index c4dc98a68..c4454c4c9 100644 --- a/management/server/job/channel.go +++ b/management/server/job/channel.go @@ -28,7 +28,13 @@ func NewChannel() *Channel { return jc } -func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) error { +func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) (err error) { + defer func() { + if r := recover(); r != nil { + err = ErrJobChannelClosed + } + }() + select { case <-ctx.Done(): return ctx.Err() diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 090c99877..4e6eb0a33 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -369,9 +370,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config requestBuffer := NewAccountRequestBuffer(ctx, store) ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", - eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { cleanup() diff --git a/management/server/management_test.go b/management/server/management_test.go index de02855bf..3ac28cd4a 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -28,6 +28,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -207,6 +208,12 @@ func startServer( jobManager := job.NewJobManager(nil, str, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("failed creating cache store: %v", err) + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) @@ -227,7 +234,8 @@ func startServer( port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, - false) + false, + cacheStore) if err != nil { t.Fatalf("failed creating an account manager: %v", err) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index f7a344fcd..8732cf89f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-version" "github.com/netbirdio/netbird/idp/dex" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/types" @@ -51,6 +52,7 @@ type properties map[string]interface{} type DataSource interface { GetAllAccounts(ctx context.Context) []*types.Account GetStoreEngine() types.Engine + GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) } // ConnManager peer connection manager that holds state for current active connections @@ -210,6 +212,17 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rosenpassEnabled int localUsers int idpUsers int + embeddedIdpTypes map[string]int + services int + servicesEnabled int + servicesTargets int + servicesStatusActive int + servicesStatusPending int + servicesStatusError int + servicesTargetType map[rpservice.TargetType]int + servicesAuthPassword int + servicesAuthPin int + servicesAuthOIDC int ) start := time.Now() metricsProperties := make(properties) @@ -218,10 +231,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rulesProtocol = make(map[string]int) rulesDirection = make(map[string]int) activeUsersLastDay = make(map[string]struct{}) + embeddedIdpTypes = make(map[string]int) + servicesTargetType = make(map[rpservice.TargetType]int) uptime = time.Since(w.startupTime).Seconds() connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() + customDomains, customDomainsValidated, _ := w.dataSource.GetCustomDomainsCounts(ctx) + for _, account := range w.dataSource.GetAllAccounts(ctx) { accounts++ @@ -278,6 +295,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties { } else { idpUsers++ } + idpType := extractIdpType(idpID) + embeddedIdpTypes[idpType]++ } } } @@ -331,6 +350,37 @@ func (w *Worker) generateProperties(ctx context.Context) properties { peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion) } } + + for _, service := range account.Services { + services++ + if service.Enabled { + servicesEnabled++ + } + servicesTargets += len(service.Targets) + + switch rpservice.Status(service.Meta.Status) { + case rpservice.StatusActive: + servicesStatusActive++ + case rpservice.StatusPending: + servicesStatusPending++ + case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated: + servicesStatusError++ + } + + for _, target := range service.Targets { + servicesTargetType[target.TargetType]++ + } + + if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled { + servicesAuthPassword++ + } + if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled { + servicesAuthPin++ + } + if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { + servicesAuthOIDC++ + } + } } minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions) @@ -369,6 +419,27 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["rosenpass_enabled"] = rosenpassEnabled metricsProperties["local_users_count"] = localUsers metricsProperties["idp_users_count"] = idpUsers + metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes) + + metricsProperties["services"] = services + metricsProperties["services_enabled"] = servicesEnabled + metricsProperties["services_targets"] = servicesTargets + metricsProperties["services_status_active"] = servicesStatusActive + metricsProperties["services_status_pending"] = servicesStatusPending + metricsProperties["services_status_error"] = servicesStatusError + metricsProperties["services_auth_password"] = servicesAuthPassword + metricsProperties["services_auth_pin"] = servicesAuthPin + metricsProperties["services_auth_oidc"] = servicesAuthOIDC + metricsProperties["custom_domains"] = customDomains + metricsProperties["custom_domains_validated"] = customDomainsValidated + + for targetType, count := range servicesTargetType { + metricsProperties["services_target_type_"+string(targetType)] = count + } + + for idpType, count := range embeddedIdpTypes { + metricsProperties["embedded_idp_users_"+idpType] = count + } for protocol, count := range rulesProtocol { metricsProperties["rules_protocol_"+protocol] = count @@ -456,6 +527,20 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string) return req, cancel, nil } +// extractIdpType extracts the IdP type from a Dex connector ID. +// Connector IDs are formatted as "-" (e.g., "okta-abc123", "zitadel-xyz"). +// Returns the type prefix, or "oidc" if no known prefix is found. +func extractIdpType(connectorID string) string { + if connectorID == "local" { + return "local" + } + idx := strings.LastIndex(connectorID, "-") + if idx <= 0 { + return "oidc" + } + return strings.ToLower(connectorID[:idx]) +} + func getMinMaxVersion(inputList []string) (string, string) { versions := make([]*version.Version, 0) diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index d0ab45cd7..78f5c53be 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -6,6 +6,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/idp/dex" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -27,7 +28,8 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { // GetAllAccounts returns a list of *server.Account for use in tests with predefined information func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { localUserID := dex.EncodeDexUserID("10", "local") - idpUserID := dex.EncodeDexUserID("20", "zitadel") + idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0") + oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40") return []*types.Account{ { Id: "1", @@ -115,6 +117,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, }, }, + Services: []*rpservice.Service{ + { + ID: "svc1", + Enabled: true, + Targets: []*rpservice.Target{ + {TargetType: "peer"}, + {TargetType: "host"}, + }, + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true}, + }, + Meta: rpservice.Meta{Status: string(rpservice.StatusActive)}, + }, + { + ID: "svc2", + Enabled: false, + Targets: []*rpservice.Target{ + {TargetType: "domain"}, + }, + Auth: rpservice.AuthConfig{ + BearerAuth: &rpservice.BearerAuthConfig{Enabled: true}, + }, + Meta: rpservice.Meta{Status: string(rpservice.StatusPending)}, + }, + }, }, { Id: "2", @@ -180,6 +207,13 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { "1": {}, }, }, + oidcUserID: { + Id: oidcUserID, + IsServiceUser: false, + PATs: map[string]*types.PersonalAccessToken{ + "1": {}, + }, + }, }, Networks: []*networkTypes.Network{ { @@ -215,6 +249,11 @@ func (mockDatasource) GetStoreEngine() types.Engine { return types.FileStoreEngine } +// GetCustomDomainsCounts returns test custom domain counts. +func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) { + return 3, 2, nil +} + // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties func TestGenerateProperties(t *testing.T) { ds := mockDatasource{} @@ -247,14 +286,14 @@ func TestGenerateProperties(t *testing.T) { if properties["rules"] != 4 { t.Errorf("expected 4 rules, got %d", properties["rules"]) } - if properties["users"] != 2 { - t.Errorf("expected 1 users, got %d", properties["users"]) + if properties["users"] != 3 { + t.Errorf("expected 3 users, got %d", properties["users"]) } if properties["setup_keys_usage"] != 2 { t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"]) } - if properties["pats"] != 4 { - t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"]) + if properties["pats"] != 5 { + t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"]) } if properties["peers_ssh_enabled"] != 2 { t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"]) @@ -338,7 +377,90 @@ func TestGenerateProperties(t *testing.T) { if properties["local_users_count"] != 1 { t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"]) } - if properties["idp_users_count"] != 1 { - t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"]) + if properties["idp_users_count"] != 2 { + t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"]) + } + if properties["embedded_idp_users_local"] != 1 { + t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"]) + } + if properties["embedded_idp_users_zitadel"] != 1 { + t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"]) + } + if properties["embedded_idp_users_oidc"] != 1 { + t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"]) + } + if properties["embedded_idp_count"] != 3 { + t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"]) + } + + if properties["services"] != 2 { + t.Errorf("expected 2 services, got %v", properties["services"]) + } + if properties["services_enabled"] != 1 { + t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"]) + } + if properties["services_targets"] != 3 { + t.Errorf("expected 3 services_targets, got %v", properties["services_targets"]) + } + if properties["services_status_active"] != 1 { + t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"]) + } + if properties["services_status_pending"] != 1 { + t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"]) + } + if properties["services_status_error"] != 0 { + t.Errorf("expected 0 services_status_error, got %v", properties["services_status_error"]) + } + if properties["services_target_type_peer"] != 1 { + t.Errorf("expected 1 services_target_type_peer, got %v", properties["services_target_type_peer"]) + } + if properties["services_target_type_host"] != 1 { + t.Errorf("expected 1 services_target_type_host, got %v", properties["services_target_type_host"]) + } + if properties["services_target_type_domain"] != 1 { + t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"]) + } + if properties["services_auth_password"] != 1 { + t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"]) + } + if properties["services_auth_oidc"] != 1 { + t.Errorf("expected 1 services_auth_oidc, got %v", properties["services_auth_oidc"]) + } + if properties["services_auth_pin"] != 0 { + t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"]) + } + if properties["custom_domains"] != int64(3) { + t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"]) + } + if properties["custom_domains_validated"] != int64(2) { + t.Errorf("expected 2 custom_domains_validated, got %v", properties["custom_domains_validated"]) + } +} + +func TestExtractIdpType(t *testing.T) { + tests := []struct { + connectorID string + expected string + }{ + {"okta-abc123def", "okta"}, + {"zitadel-d5uv82dra0haedlf6kv0", "zitadel"}, + {"entra-xyz789", "entra"}, + {"google-abc123", "google"}, + {"pocketid-abc123", "pocketid"}, + {"microsoft-abc123", "microsoft"}, + {"authentik-abc123", "authentik"}, + {"keycloak-d5uv82dra0haedlf6kv0", "keycloak"}, + {"local", "local"}, + {"d6jvvp69kmnc73c9pl40", "oidc"}, + {"", "oidc"}, + } + + for _, tt := range tests { + t.Run(tt.connectorID, func(t *testing.T) { + result := extractIdpType(tt.connectorID) + if result != tt.expected { + t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected) + } + }) } } diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 29555ed0c..7a51cc200 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -489,6 +489,102 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri return nil } +// hasForeignKey checks whether a foreign key constraint exists on the given table and column. +func hasForeignKey(db *gorm.DB, table, column string) bool { + var count int64 + + switch db.Name() { + case "postgres": + db.Raw(` + SELECT COUNT(*) FROM information_schema.key_column_usage kcu + JOIN information_schema.table_constraints tc + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND kcu.table_name = ? + AND kcu.column_name = ? + `, table, column).Scan(&count) + case "mysql": + db.Raw(` + SELECT COUNT(*) FROM information_schema.key_column_usage + WHERE table_schema = DATABASE() + AND table_name = ? + AND column_name = ? + AND referenced_table_name IS NOT NULL + `, table, column).Scan(&count) + default: // sqlite + type fkInfo struct { + From string + } + var fks []fkInfo + db.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks) + for _, fk := range fks { + if fk.From == column { + return true + } + } + return false + } + + return count > 0 +} + +// CleanupOrphanedResources deletes rows from the table of model T where the foreign +// key column (fkColumn) references a row in the table of model R that no longer exists. +func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error { + var model T + var refModel R + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model) + return nil + } + + if !db.Migrator().HasTable(&refModel) { + log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel) + return nil + } + + stmtT := &gorm.Statement{DB: db} + if err := stmtT.Parse(&model); err != nil { + return fmt.Errorf("parse model %T: %w", model, err) + } + childTable := stmtT.Schema.Table + + stmtR := &gorm.Statement{DB: db} + if err := stmtR.Parse(&refModel); err != nil { + return fmt.Errorf("parse reference model %T: %w", refModel, err) + } + parentTable := stmtR.Schema.Table + + if !db.Migrator().HasColumn(&model, fkColumn) { + log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable) + return nil + } + + // If a foreign key constraint already exists on the column, the DB itself + // enforces referential integrity and orphaned rows cannot exist. + if hasForeignKey(db, childTable, fkColumn) { + log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable) + return nil + } + + result := db.Exec( + fmt.Sprintf( + "DELETE FROM %s WHERE %s NOT IN (SELECT id FROM %s)", + childTable, fkColumn, parentTable, + ), + ) + if result.Error != nil { + return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error) + } + + log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s", + result.RowsAffected, childTable, fkColumn, parentTable) + + return nil +} + func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error { if !db.Migrator().HasTable("peers") { log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup") diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index c1be8a3a3..5e00976c2 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -441,3 +441,197 @@ func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) { err := migration.RemoveDuplicatePeerKeys(context.Background(), db) require.NoError(t, err, "Should not fail when table does not exist") } + +type testParent struct { + ID string `gorm:"primaryKey"` +} + +func (testParent) TableName() string { + return "test_parents" +} + +type testChild struct { + ID string `gorm:"primaryKey"` + ParentID string +} + +func (testChild) TableName() string { + return "test_children" +} + +type testChildWithFK struct { + ID string `gorm:"primaryKey"` + ParentID string `gorm:"index"` + Parent *testParent `gorm:"foreignKey:ParentID"` +} + +func (testChildWithFK) TableName() string { + return "test_children" +} + +func setupOrphanTestDB(t *testing.T, models ...any) *gorm.DB { + t.Helper() + db := setupDatabase(t) + for _, m := range models { + _ = db.Migrator().DropTable(m) + } + err := db.AutoMigrate(models...) + require.NoError(t, err, "Failed to auto-migrate tables") + return db +} + +func TestCleanupOrphanedResources_NoChildTable(t *testing.T) { + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testChild{}) + _ = db.Migrator().DropTable(&testParent{}) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err, "Should not fail when child table does not exist") +} + +func TestCleanupOrphanedResources_NoParentTable(t *testing.T) { + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testParent{}) + _ = db.Migrator().DropTable(&testChild{}) + + err := db.AutoMigrate(&testChild{}) + require.NoError(t, err) + + err = migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err, "Should not fail when parent table does not exist") +} + +func TestCleanupOrphanedResources_EmptyTables(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err, "Should not fail on empty tables") + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(0), count) +} + +func TestCleanupOrphanedResources_NoOrphans(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testParent{ID: "p2"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(2), count, "All children should remain when no orphans") +} + +func TestCleanupOrphanedResources_AllOrphans(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c1", "gone1").Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone2").Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c3", "gone3").Error) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(0), count, "All orphaned children should be deleted") +} + +func TestCleanupOrphanedResources_MixedValidAndOrphaned(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testParent{ID: "p2"}).Error) + + require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c3", ParentID: "p1"}).Error) + + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c4", "gone1").Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c5", "gone2").Error) + + err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var remaining []testChild + require.NoError(t, db.Order("id").Find(&remaining).Error) + + assert.Len(t, remaining, 3, "Only valid children should remain") + assert.Equal(t, "c1", remaining[0].ID) + assert.Equal(t, "c2", remaining[1].ID) + assert.Equal(t, "c3", remaining[2].ID) +} + +func TestCleanupOrphanedResources_Idempotent(t *testing.T) { + db := setupOrphanTestDB(t, &testParent{}, &testChild{}) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error) + + ctx := context.Background() + + err := migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(1), count) + + err = migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id") + require.NoError(t, err) + + db.Model(&testChild{}).Count(&count) + assert.Equal(t, int64(1), count, "Count should remain the same after second run") +} + +func TestCleanupOrphanedResources_SkipsWhenForeignKeyExists(t *testing.T) { + engine := os.Getenv("NETBIRD_STORE_ENGINE") + if engine != "postgres" && engine != "mysql" { + t.Skip("FK constraint early-exit test requires postgres or mysql") + } + + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testChildWithFK{}) + _ = db.Migrator().DropTable(&testParent{}) + + err := db.AutoMigrate(&testParent{}, &testChildWithFK{}) + require.NoError(t, err) + + require.NoError(t, db.Create(&testParent{ID: "p1"}).Error) + require.NoError(t, db.Create(&testParent{ID: "p2"}).Error) + require.NoError(t, db.Create(&testChildWithFK{ID: "c1", ParentID: "p1"}).Error) + require.NoError(t, db.Create(&testChildWithFK{ID: "c2", ParentID: "p2"}).Error) + + switch engine { + case "postgres": + require.NoError(t, db.Exec("ALTER TABLE test_children DROP CONSTRAINT fk_test_children_parent").Error) + require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error) + require.NoError(t, db.Exec( + "ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+ + "FOREIGN KEY (parent_id) REFERENCES test_parents(id) NOT VALID", + ).Error) + case "mysql": + require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error) + require.NoError(t, db.Exec("ALTER TABLE test_children DROP FOREIGN KEY fk_test_children_parent").Error) + require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error) + require.NoError(t, db.Exec( + "ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+ + "FOREIGN KEY (parent_id) REFERENCES test_parents(id)", + ).Error) + require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error) + } + + err = migration.CleanupOrphanedResources[testChildWithFK, testParent](context.Background(), db, "parent_id") + require.NoError(t, err) + + var count int64 + db.Model(&testChildWithFK{}).Count(&count) + assert.Equal(t, int64(2), count, "Both rows should survive — migration must skip when FK constraint exists") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 032b1150f..ff369355e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -12,7 +12,7 @@ import ( "google.golang.org/grpc/status" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" @@ -46,7 +46,7 @@ type MockAccountManager struct { AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error @@ -148,7 +148,7 @@ type MockAccountManager struct { DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error } -func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { +func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) { // Mock implementation - no-op } @@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { - if am.GetGroupFunc != nil { - return am.GetGroupByNameFunc(ctx, accountID, groupName) +func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + if am.GetGroupByNameFunc != nil { + return am.GetGroupByNameFunc(ctx, groupName, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 90b4b9687..d10d4464f 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -17,6 +17,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -794,11 +795,17 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 843ca93e5..86f9b6579 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -7,7 +7,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" @@ -33,23 +33,23 @@ type Manager interface { } type managerImpl struct { - store store.Store - permissionsManager permissions.Manager - groupsManager groups.Manager - accountManager account.Manager - reverseProxyManager reverseproxy.Manager + store store.Store + permissionsManager permissions.Manager + groupsManager groups.Manager + accountManager account.Manager + serviceManager service.Manager } type mockManager struct { } -func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - groupsManager: groupsManager, - accountManager: accountManager, - reverseProxyManager: reverseproxyManager, + store: store, + permissionsManager: permissionsManager, + groupsManager: groupsManager, + accountManager: accountManager, + serviceManager: reverseproxyManager, } } @@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc // TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them go func() { - err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID) + err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID) if err != nil { log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err) } @@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return status.NewPermissionDeniedError() } - serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID) + serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID) if err != nil { return fmt.Errorf("failed to check if resource is used by service: %w", err) } diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index 99de484e5..c6d8e7bcc 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -7,7 +7,7 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.NoError(t, err) @@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.Error(t, err) @@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) @@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.NoError(t, err) @@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.NoError(t, err) @@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index e90c61a97..1293a9934 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -21,11 +21,7 @@ type NetworkRouter struct { } func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) { - if peer != "" && len(peerGroups) > 0 { - return nil, errors.New("peer and peerGroups cannot be set at the same time") - } - - return &NetworkRouter{ + r := &NetworkRouter{ ID: xid.New().String(), AccountID: accountID, NetworkID: networkID, @@ -34,7 +30,25 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup Masquerade: masquerade, Metric: metric, Enabled: enabled, - }, nil + } + + if err := r.Validate(); err != nil { + return nil, err + } + + return r, nil +} + +func (n *NetworkRouter) Validate() error { + if n.Peer != "" && len(n.PeerGroups) > 0 { + return errors.New("peer and peer_groups cannot be set at the same time") + } + + if n.Peer == "" && len(n.PeerGroups) == 0 { + return errors.New("either peer or peer_groups must be provided") + } + + return nil } func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go index 5801e3bfa..a2f2fe6e3 100644 --- a/management/server/networks/routers/types/router_test.go +++ b/management/server/networks/routers/types/router_test.go @@ -38,7 +38,7 @@ func TestNewNetworkRouter(t *testing.T) { expectedError: false, }, { - name: "Valid with no peer or peerGroups", + name: "Invalid with no peer or peerGroups", networkID: "network-3", accountID: "account-3", peer: "", @@ -46,7 +46,18 @@ func TestNewNetworkRouter(t *testing.T) { masquerade: true, metric: 300, enabled: true, - expectedError: false, + expectedError: true, + }, + { + name: "Invalid with empty peerGroups slice", + networkID: "network-5", + accountID: "account-5", + peer: "", + peerGroups: []string{}, + masquerade: true, + metric: 500, + enabled: true, + expectedError: true, }, // Invalid cases diff --git a/management/server/peer.go b/management/server/peer.go index a2ca97208..a02e34e0d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if err != nil { newLabel = "" } else { - _, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name) + _, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, newLabel) if err == nil { newLabel = "" } @@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var settings *types.Settings var eventsToStore []func() - serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID) + serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID) if err != nil { return fmt.Errorf("failed to check if resource is used by service: %w", err) } @@ -859,7 +859,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName } - am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if !temporary { + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + } if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) @@ -1480,9 +1482,11 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - peerDeletedEvents = append(peerDeletedEvents, func() { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) - }) + if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") { + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + } } return peerDeletedEvents, nil diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 269b30822..db392ddda 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -352,9 +352,10 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) p.Name = a.Name p.Key = a.WgPubKey p.Meta = PeerSystemMeta{ - Hostname: a.Name, - GoOS: "js", - OS: "js", + Hostname: a.Name, + GoOS: "js", + OS: "js", + KernelVersion: "wasm", } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b17757ffd..6f8d924fd 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -32,11 +32,13 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" @@ -1293,11 +1295,15 @@ func Test_RegisterPeerByUser(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1379,11 +1385,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1533,11 +1543,15 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1614,11 +1628,15 @@ func Test_LoginPeer(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -2738,3 +2756,70 @@ func TestProcessPeerAddAuth(t *testing.T) { assert.Empty(t, config.GroupsToAdd) }) } + +func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + // Add first peer with hostname that produces DNS label "netbird1" + key1, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"}, + }, false) + require.NoError(t, err, "unable to add first peer") + assert.Equal(t, "netbird1", peer1.DNSLabel) + + // Add second peer with a different hostname + key2, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"}, + }, false) + require.NoError(t, err) + + update := peer2.Copy() + update.Name = "netbird1.demo.netbird.cloud" + updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update) + require.NoError(t, err, "renaming peer should not fail with duplicate DNS label error") + assert.Equal(t, "netbird1.demo.netbird.cloud", updated.Name) + assert.NotEqual(t, "netbird1", updated.DNSLabel, "DNS label should not collide with existing peer") + assert.Contains(t, updated.DNSLabel, "netbird1-", "DNS label should be IP-based fallback") +} + +func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + key1, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"}, + }, false) + require.NoError(t, err) + assert.Equal(t, "web-server", peer1.DNSLabel) + + // Add second peer and rename it to a unique FQDN whose first label doesn't collide + key2, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"}, + }, false) + require.NoError(t, err) + + update := peer2.Copy() + update.Name = "api-server.example.com" + updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update) + require.NoError(t, err, "renaming to unique FQDN should succeed") + assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN") +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index ba901c771..9562487c0 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -84,7 +84,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) } diff --git a/management/server/route_test.go b/management/server/route_test.go index d4882eff8..91b2cf982 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -20,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1293,11 +1294,17 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index 8db37ec30..81185b020 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -269,3 +269,8 @@ func (s *FileStore) GetStoreEngine() types.Engine { func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) { // no-op: FileStore stores data in plaintext JSON; encryption is not supported } + +// GetCustomDomainsCounts is a no-op for FileStore as it doesn't support custom domains. +func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) { + return 0, 0, nil +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 018e54810..8189548b7 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -28,9 +28,11 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -131,8 +133,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, - &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{}, - &accesslogs.AccessLogEntry{}, + &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{}, + &accesslogs.AccessLogEntry{}, &proxy.Proxy{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -394,6 +396,11 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er return result.Error } + result = tx.Select(clause.Associations).Delete(account.Services, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + result = tx.Select(clause.Associations).Delete(account) if result.Error != nil { return result.Error @@ -1007,6 +1014,18 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) { return count, nil } +// GetCustomDomainsCounts returns the total and validated custom domain counts. +func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) { + var total, validated int64 + if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil { + return 0, 0, err + } + if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil { + return 0, 0, err + } + return total, validated, nil +} + func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { var accounts []types.Account result := s.db.Find(&accounts) @@ -2063,10 +2082,11 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p return checks, nil } -func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) { const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth, meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, - pass_host_header, rewrite_redirects, session_private_key, session_public_key + pass_host_header, rewrite_redirects, session_private_key, session_public_key, + mode, listen_port, port_auto_assigned, source, source_peer, terminated FROM services WHERE account_id = $1` const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol, @@ -2078,11 +2098,14 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers return nil, err } - services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) { - var s reverseproxy.Service + services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) { + var s rpservice.Service var auth []byte var createdAt, certIssuedAt sql.NullTime var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString + var mode, source, sourcePeer sql.NullString + var terminated, portAutoAssigned sql.NullBool + var listenPort sql.NullInt64 err := row.Scan( &s.ID, &s.AccountID, @@ -2098,6 +2121,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers &s.RewriteRedirects, &sessionPrivateKey, &sessionPublicKey, + &mode, + &listenPort, + &portAutoAssigned, + &source, + &sourcePeer, + &terminated, ) if err != nil { return nil, err @@ -2109,12 +2138,13 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers } } - s.Meta = reverseproxy.ServiceMeta{} + s.Meta = rpservice.Meta{} if createdAt.Valid { s.Meta.CreatedAt = createdAt.Time } if certIssuedAt.Valid { - s.Meta.CertificateIssuedAt = certIssuedAt.Time + t := certIssuedAt.Time + s.Meta.CertificateIssuedAt = &t } if status.Valid { s.Meta.Status = status.String @@ -2128,8 +2158,25 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers if sessionPublicKey.Valid { s.SessionPublicKey = sessionPublicKey.String } - - s.Targets = []*reverseproxy.Target{} + if mode.Valid { + s.Mode = mode.String + } + if source.Valid { + s.Source = source.String + } + if sourcePeer.Valid { + s.SourcePeer = sourcePeer.String + } + if terminated.Valid { + s.Terminated = terminated.Bool + } + if portAutoAssigned.Valid { + s.PortAutoAssigned = portAutoAssigned.Bool + } + if listenPort.Valid { + s.ListenPort = uint16(listenPort.Int64) + } + s.Targets = []*rpservice.Target{} return &s, nil }) if err != nil { @@ -2141,7 +2188,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers } serviceIDs := make([]string, len(services)) - serviceMap := make(map[string]*reverseproxy.Service) + serviceMap := make(map[string]*rpservice.Service) for i, s := range services { serviceIDs[i] = s.ID serviceMap[s.ID] = s @@ -2152,8 +2199,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers return nil, err } - targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) { - var t reverseproxy.Target + targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) { + var t rpservice.Target var path sql.NullString err := row.Scan( &t.ID, @@ -2715,14 +2762,28 @@ func (s *SqlStore) GetStoreEngine() types.Engine { // NewSqliteStore creates a new SQLite store. func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { - storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) - if runtime.GOOS == "windows" { - // Vo avoid `The process cannot access the file because it is being used by another process` on Windows - storeStr = storeSqliteFileName + storeFile := storeSqliteFileName + if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" { + storeFile = envFile } - file := filepath.Join(dataDir, storeStr) - db, err := gorm.Open(sqlite.Open(file), getGormConfig()) + // Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc") + filePath, query, hasQuery := strings.Cut(storeFile, "?") + + connStr := filePath + if !filepath.IsAbs(filePath) { + connStr = filepath.Join(dataDir, filePath) + } + + // Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows + if hasQuery { + connStr += "?" + query + } else if runtime.GOOS != "windows" { + // To avoid `The process cannot access the file because it is being used by another process` on Windows + connStr += "?cache=shared" + } + + db, err := gorm.Open(sqlite.Open(connStr), getGormConfig()) if err != nil { return nil, err } @@ -4381,7 +4442,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error { // GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value. func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -4400,7 +4461,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr // GetAllProxyAccessTokens retrieves all proxy access tokens. func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -4416,7 +4477,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc // SaveProxyAccessToken saves a proxy access token to the database. func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error { - if result := s.db.WithContext(ctx).Create(token); result.Error != nil { + if result := s.db.Create(token); result.Error != nil { return status.Errorf(status.Internal, "save proxy access token: %v", result.Error) } return nil @@ -4424,7 +4485,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA // RevokeProxyAccessToken revokes a proxy access token by its ID. func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { - result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true) + result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true) if result.Error != nil { return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error) } @@ -4438,7 +4499,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { - result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}). + result := s.db.Model(&types.ProxyAccessToken{}). Where(idQueryCondition, tokenID). Update("last_used", time.Now().UTC()) if result.Error != nil { @@ -4825,7 +4886,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren return peerID, nil } -func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { +func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error { serviceCopy := service.Copy() if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt service data: %w", err) @@ -4839,16 +4900,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv return nil } -func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { +func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error { serviceCopy := service.Copy() if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt service data: %w", err) } + // Create target type instance outside transaction to avoid variable shadowing + targetType := &rpservice.Target{} + // Use a transaction to ensure atomic updates of the service and its targets err := s.db.Transaction(func(tx *gorm.DB) error { // Delete existing targets - if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil { + if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil { return err } @@ -4869,7 +4933,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv } func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error { - result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID) + result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete service from store") @@ -4882,13 +4946,53 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin return nil } -func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { +func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error { + result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete target from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "target not found for service %s", serviceID) + } + + return nil +} + +func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error { + result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete targets from store") + } + + return nil +} + +// GetTargetsByServiceID retrieves all targets for a given service +func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) { + var targets []*rpservice.Target + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get targets from store") + } + + return targets, nil +} + +func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var service *reverseproxy.Service + var service *rpservice.Service result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -4906,9 +5010,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren return service, nil } -func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { - var service *reverseproxy.Service - result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) +func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { + var service *rpservice.Service + result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain) @@ -4925,13 +5029,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str return service, nil } -func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { +func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var serviceList []*reverseproxy.Service + var serviceList []*rpservice.Service result := tx.Find(&serviceList) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) @@ -4947,13 +5051,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength return serviceList, nil } -func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { +func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var serviceList []*reverseproxy.Service + var serviceList []*rpservice.Service result := tx.Find(&serviceList, accountIDCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) @@ -4969,6 +5073,130 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS return serviceList, nil } +// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service. +func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { + result := s.db.Model(&rpservice.Service{}). + Where("id = ? AND account_id = ? AND source_peer = ? AND source = ?", serviceID, accountID, peerID, rpservice.SourceEphemeral). + Update("meta_last_renewed_at", time.Now()) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error) + return status.Errorf(status.Internal, "renew ephemeral service") + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "no active expose session for service %s", serviceID) + } + return nil +} + +// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL. +// Only the fields needed for reaping are selected. The limit parameter caps the batch size to +// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to +// skip malformed legacy data. +func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) { + cutoff := time.Now().Add(-ttl) + var services []*rpservice.Service + result := s.db. + Select("id", "account_id", "source_peer", "domain"). + Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff). + Limit(limit). + Find(&services) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error) + return nil, status.Errorf(status.Internal, "get expired ephemeral services") + } + return services, nil +} + +// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer. +// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations. +// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to +// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*). +func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) { + if lockStrength == LockingStrengthNone { + var count int64 + result := s.db.Model(&rpservice.Service{}). + Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral). + Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error) + return 0, status.Errorf(status.Internal, "count ephemeral services") + } + return count, nil + } + + var ids []string + result := s.db.Model(&rpservice.Service{}). + Clauses(clause.Locking{Strength: string(lockStrength)}). + Select("id"). + Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral). + Pluck("id", &ids) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error) + return 0, status.Errorf(status.Internal, "count ephemeral services") + } + return int64(len(ids)), nil +} + +// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain. +// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations. +func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { + if lockStrength == LockingStrengthNone { + var count int64 + result := s.db.Model(&rpservice.Service{}). + Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error) + return false, status.Errorf(status.Internal, "check ephemeral service existence") + } + return count > 0, nil + } + + var id string + result := s.db.Model(&rpservice.Service{}). + Clauses(clause.Locking{Strength: string(lockStrength)}). + Select("id"). + Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Limit(1). + Pluck("id", &id) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error) + return false, status.Errorf(status.Internal, "check ephemeral service existence") + } + return id != "", nil +} + +// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port. +func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ? AND mode = ? AND listen_port = ?", proxyCluster, mode, listenPort).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster and port") + } + + return services, nil +} + +// GetServicesByCluster returns all services for the given proxy cluster. +func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ?", proxyCluster).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster") + } + return services, nil +} + func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { tx := s.db @@ -5066,7 +5294,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin var logs []*accesslogs.AccessLogEntry var totalCount int64 - baseQuery := s.db.WithContext(ctx). + baseQuery := s.db. Model(&accesslogs.AccessLogEntry{}). Where(accountIDCondition, accountID) @@ -5077,7 +5305,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin return nil, 0, status.Errorf(status.Internal, "failed to count access logs") } - query := s.db.WithContext(ctx). + query := s.db. Where(accountIDCondition, accountID) query = s.applyAccessLogFilters(query, filter) @@ -5114,7 +5342,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin // DeleteOldAccessLogs deletes all access logs older than the specified time func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) { - result := s.db.WithContext(ctx). + result := s.db. Where("timestamp < ?", olderThan). Delete(&accesslogs.AccessLogEntry{}) @@ -5181,13 +5409,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces return query } -func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) { +func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var target *reverseproxy.Target + var target *rpservice.Target result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -5200,3 +5428,270 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength return target, nil } + +// SaveProxy saves or updates a proxy in the database +func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { + result := s.db.Save(p) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error) + return status.Errorf(status.Internal, "failed to save proxy") + } + return nil +} + +// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist +func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + now := time.Now() + + result := s.db. + Model(&proxy.Proxy{}). + Where("id = ? AND status = ?", proxyID, "connected"). + Update("last_seen", now) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error) + return status.Errorf(status.Internal, "failed to update proxy heartbeat") + } + + if result.RowsAffected == 0 { + p := &proxy.Proxy{ + ID: proxyID, + ClusterAddress: clusterAddress, + IPAddress: ipAddress, + LastSeen: now, + ConnectedAt: &now, + Status: "connected", + } + if err := s.db.Save(p).Error; err != nil { + log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) + return status.Errorf(status.Internal, "failed to create proxy on heartbeat") + } + } + + return nil +} + +// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies +func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + var addresses []string + + result := s.db. + Model(&proxy.Proxy{}). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Distinct("cluster_address"). + Pluck("cluster_address", &addresses) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses") + } + + return addresses, nil +} + +// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { + var clusters []proxy.Cluster + + result := s.db.Model(&proxy.Proxy{}). + Select("cluster_address as address, COUNT(*) as connected_proxies"). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Group("cluster_address"). + Scan(&clusters) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error) + return nil, status.Errorf(status.Internal, "get active proxy clusters") + } + + return clusters, nil +} + +// proxyActiveThreshold is the maximum age of a heartbeat for a proxy to be +// considered active. Must be at least 2x the heartbeat interval (1 min). +const proxyActiveThreshold = 2 * time.Minute + +var validCapabilityColumns = map[string]struct{}{ + "supports_custom_ports": {}, + "require_subdomain": {}, + "supports_crowdsec": {}, +} + +// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports") +} + +// GetClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "require_subdomain") +} + +// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster +// have CrowdSec configured. Returns nil when no proxy reported the capability. +// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec +// requires unanimous support: a single unconfigured proxy would let requests +// bypass reputation checks. +func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec") +} + +// getClusterUnanimousCapability returns an aggregated boolean capability +// requiring all active proxies in the cluster to report true. +func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool { + if _, ok := validCapabilityColumns[column]; !ok { + log.WithContext(ctx).Errorf("invalid capability column: %s", column) + return nil + } + + var result struct { + Total int64 + Reported int64 + AllTrue bool + } + + // All active proxies must have reported the capability (no NULLs) and all + // must report true. A single unreported or false proxy means the cluster + // does not unanimously support the capability. + err := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Select("COUNT(*) AS total, "+ + "COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+ + "COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true"). + Where("cluster_address = ? AND status = ? AND last_seen > ?", + clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). + Scan(&result).Error + + if err != nil { + log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) + return nil + } + + if result.Total == 0 || result.Reported == 0 { + return nil + } + + // If any proxy has not reported (NULL), we can't confirm unanimous support. + if result.Reported < result.Total { + v := false + return &v + } + + return &result.AllTrue +} + +// getClusterCapability returns an aggregated boolean capability for the given +// cluster. It checks active (connected, recently seen) proxies and returns: +// - *true if any proxy in the cluster has the capability set to true, +// - *false if at least one proxy reported but none set it to true, +// - nil if no proxy reported the capability at all. +func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool { + if _, ok := validCapabilityColumns[column]; !ok { + log.WithContext(ctx).Errorf("invalid capability column: %s", column) + return nil + } + + var result struct { + HasCapability bool + AnyTrue bool + } + + err := s.db. + Model(&proxy.Proxy{}). + Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+ + "COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true"). + Where("cluster_address = ? AND status = ? AND last_seen > ?", + clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). + Scan(&result).Error + + if err != nil { + log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) + return nil + } + + if !result.HasCapability { + return nil + } + + return &result.AnyTrue +} + +// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration +func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { + cutoffTime := time.Now().Add(-inactivityDuration) + + result := s.db. + Where("last_seen < ?", cutoffTime). + Delete(&proxy.Proxy{}) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error) + return status.Errorf(status.Internal, "failed to cleanup stale proxies") + } + + if result.RowsAffected > 0 { + log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected) + } + + return nil +} + +// GetRoutingPeerNetworks returns the distinct network names where the peer is assigned as a routing peer +// in an enabled network router, either directly or via peer groups. +func (s *SqlStore) GetRoutingPeerNetworks(_ context.Context, accountID, peerID string) ([]string, error) { + var routers []*routerTypes.NetworkRouter + if err := s.db.Select("peer, peer_groups, network_id").Where("account_id = ? AND enabled = true", accountID).Find(&routers).Error; err != nil { + return nil, status.Errorf(status.Internal, "failed to get enabled routers: %v", err) + } + + if len(routers) == 0 { + return nil, nil + } + + var groupPeers []types.GroupPeer + if err := s.db.Select("group_id").Where("account_id = ? AND peer_id = ?", accountID, peerID).Find(&groupPeers).Error; err != nil { + return nil, status.Errorf(status.Internal, "failed to get peer group memberships: %v", err) + } + + groupSet := make(map[string]struct{}, len(groupPeers)) + for _, gp := range groupPeers { + groupSet[gp.GroupID] = struct{}{} + } + + networkIDs := make(map[string]struct{}) + for _, r := range routers { + if r.Peer == peerID { + networkIDs[r.NetworkID] = struct{}{} + } else if r.Peer == "" { + for _, pg := range r.PeerGroups { + if _, ok := groupSet[pg]; ok { + networkIDs[r.NetworkID] = struct{}{} + break + } + } + } + } + + if len(networkIDs) == 0 { + return nil, nil + } + + ids := make([]string, 0, len(networkIDs)) + for id := range networkIDs { + ids = append(ids, id) + } + + var networks []*networkTypes.Network + if err := s.db.Select("name").Where("account_id = ? AND id IN ?", accountID, ids).Find(&networks).Error; err != nil { + return nil, status.Errorf(status.Internal, "failed to get networks: %v", err) + } + + names := make([]string, 0, len(networks)) + for _, n := range networks { + names = append(names, n.Name) + } + + return names, nil +} diff --git a/management/server/store/sql_store_idp_migration.go b/management/server/store/sql_store_idp_migration.go new file mode 100644 index 000000000..64962845b --- /dev/null +++ b/management/server/store/sql_store_idp_migration.go @@ -0,0 +1,177 @@ +package store + +// This file contains migration-only methods on SqlStore. +// They satisfy the migration.Store interface via duck typing. +// Delete this file when migration tooling is no longer needed. + +import ( + "context" + "fmt" + + log "github.com/sirupsen/logrus" + "gorm.io/gorm" + + "github.com/netbirdio/netbird/management/server/idp/migration" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError { + migrator := s.db.Migrator() + var errs []migration.SchemaError + + for _, check := range checks { + if !migrator.HasTable(check.Table) { + errs = append(errs, migration.SchemaError{Table: check.Table}) + continue + } + for _, col := range check.Columns { + if !migrator.HasColumn(check.Table, col) { + errs = append(errs, migration.SchemaError{Table: check.Table, Column: col}) + } + } + } + + return errs +} + +func (s *SqlStore) ListUsers(ctx context.Context) ([]*types.User, error) { + tx := s.db + var users []*types.User + result := tx.Find(&users) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when listing users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue listing users from store") + } + + for _, user := range users { + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } + } + + return users, nil +} + +// txDeferFKConstraints defers foreign key constraint checks for the duration of the transaction. +// MySQL is already handled by s.transaction (SET FOREIGN_KEY_CHECKS = 0). +func (s *SqlStore) txDeferFKConstraints(tx *gorm.DB) error { + if s.storeEngine == types.SqliteStoreEngine { + return tx.Exec("PRAGMA defer_foreign_keys = ON").Error + } + + if s.storeEngine != types.PostgresStoreEngine { + return nil + } + + // GORM creates FK constraints as NOT DEFERRABLE by default, so + // SET CONSTRAINTS ALL DEFERRED is a no-op unless we ALTER them first. + err := tx.Exec(` + DO $$ DECLARE r RECORD; + BEGIN + FOR r IN SELECT conname, conrelid::regclass AS tbl + FROM pg_constraint WHERE contype = 'f' AND NOT condeferrable + LOOP + EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I DEFERRABLE INITIALLY IMMEDIATE', r.tbl, r.conname); + END LOOP; + END $$ + `).Error + if err != nil { + return fmt.Errorf("make FK constraints deferrable: %w", err) + } + return tx.Exec("SET CONSTRAINTS ALL DEFERRED").Error +} + +// txRestoreFKConstraints reverts FK constraints back to NOT DEFERRABLE after the +// deferred updates are done but before the transaction commits. +func (s *SqlStore) txRestoreFKConstraints(tx *gorm.DB) error { + if s.storeEngine != types.PostgresStoreEngine { + return nil + } + + return tx.Exec(` + DO $$ DECLARE r RECORD; + BEGIN + FOR r IN SELECT conname, conrelid::regclass AS tbl + FROM pg_constraint WHERE contype = 'f' AND condeferrable + LOOP + EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I NOT DEFERRABLE', r.tbl, r.conname); + END LOOP; + END $$ + `).Error +} + +func (s *SqlStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error { + user := &types.User{Email: email, Name: name} + if err := user.EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt user info: %w", err) + } + + result := s.db.Model(&types.User{}).Where("id = ?", userID).Updates(map[string]any{ + "email": user.Email, + "name": user.Name, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("error updating user info for %s: %s", userID, result.Error) + return status.Errorf(status.Internal, "failed to update user info") + } + + return nil +} + +func (s *SqlStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error { + type fkUpdate struct { + model any + column string + where string + } + + updates := []fkUpdate{ + {&types.PersonalAccessToken{}, "user_id", "user_id = ?"}, + {&types.PersonalAccessToken{}, "created_by", "created_by = ?"}, + {&nbpeer.Peer{}, "user_id", "user_id = ?"}, + {&types.UserInviteRecord{}, "created_by", "created_by = ?"}, + {&types.Account{}, "created_by", "created_by = ?"}, + {&types.ProxyAccessToken{}, "created_by", "created_by = ?"}, + {&types.Job{}, "triggered_by", "triggered_by = ?"}, + } + + log.Info("Updating user ID in the store") + err := s.transaction(func(tx *gorm.DB) error { + if err := s.txDeferFKConstraints(tx); err != nil { + return err + } + + for _, u := range updates { + if err := tx.Model(u.model).Where(u.where, oldUserID).Update(u.column, newUserID).Error; err != nil { + return fmt.Errorf("update %s: %w", u.column, err) + } + } + + if err := tx.Model(&types.User{}).Where(accountAndIDQueryCondition, accountID, oldUserID).Update("id", newUserID).Error; err != nil { + return fmt.Errorf("update users: %w", err) + } + + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to update user ID in the store: %s", err) + return status.Errorf(status.Internal, "failed to update user ID in store") + } + + log.Info("Restoring FK constraints") + err = s.transaction(func(tx *gorm.DB) error { + if err := s.txRestoreFKConstraints(tx); err != nil { + return fmt.Errorf("restore FK constraints: %w", err) + } + + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to restore FK constraints after user ID update: %s", err) + return status.Errorf(status.Internal, "failed to restore FK constraints after user ID update") + } + + return nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index bafa63580..8ea6c2ae5 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -22,6 +22,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -350,6 +352,35 @@ func TestSqlite_DeleteAccount(t *testing.T) { }, } + account.Services = []*rpservice.Service{ + { + ID: "service_id", + AccountID: account.Id, + Name: "test service", + Domain: "svc.example.com", + Enabled: true, + Targets: []*rpservice.Target{ + { + AccountID: account.Id, + ServiceID: "service_id", + Host: "localhost", + Port: 8080, + Protocol: "http", + Enabled: true, + }, + }, + }, + } + + account.Domains = []*proxydomain.Domain{ + { + ID: "domain_id", + Domain: "custom.example.com", + AccountID: account.Id, + Validated: true, + }, + } + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -411,6 +442,20 @@ func TestSqlite_DeleteAccount(t *testing.T) { require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources") require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") } + + domains, err := store.ListCustomDomains(context.Background(), account.Id) + require.NoError(t, err, "expecting no error after DeleteAccount when searching for custom domains") + require.Len(t, domains, 0, "expecting no custom domains to be found after DeleteAccount") + + var services []*rpservice.Service + err = store.(*SqlStore).db.Model(&rpservice.Service{}).Find(&services, "account_id = ?", account.Id).Error + require.NoError(t, err, "expecting no error after DeleteAccount when searching for services") + require.Len(t, services, 0, "expecting no services to be found after DeleteAccount") + + var targets []*rpservice.Target + err = store.(*SqlStore).db.Model(&rpservice.Target{}).Find(&targets, "account_id = ?", account.Id).Error + require.NoError(t, err, "expecting no error after DeleteAccount when searching for service targets") + require.Len(t, targets, 0, "expecting no service targets to be found after DeleteAccount") } func Test_GetAccount(t *testing.T) { diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index fa9a9dbf5..81c4b33ae 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -20,7 +20,8 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -264,7 +265,8 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, - &types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{}, + &types.AccountOnboarding{}, &service.Service{}, &service.Target{}, + &domain.Domain{}, } for i := len(models) - 1; i >= 0; i-- { diff --git a/management/server/store/store.go b/management/server/store/store.go index 2bc688a11..0d8b0678a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -25,9 +25,10 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/telemetry" @@ -120,7 +121,7 @@ type Store interface { GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error @@ -252,13 +253,20 @@ type Store interface { MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) - CreateService(ctx context.Context, service *reverseproxy.Service) error - UpdateService(ctx context.Context, service *reverseproxy.Service) error + CreateService(ctx context.Context, service *rpservice.Service) error + UpdateService(ctx context.Context, service *rpservice.Service) error DeleteService(ctx context.Context, accountID, serviceID string) error - GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) - GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) - GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) - GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) + GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) + GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) + GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) + GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) + + RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error + GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) + CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) + EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) + GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) + GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) @@ -270,7 +278,23 @@ type Store interface { CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) - GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) + GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) + GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) + DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error + DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error + + SaveProxy(ctx context.Context, proxy *proxy.Proxy) error + UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + + GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) + + GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) } const ( @@ -425,6 +449,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.RemoveDuplicatePeerKeys(ctx, db) }, + func(db *gorm.DB) error { + return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id") + }, + func(db *gorm.DB) error { + return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id") + }, } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 79d275298..beee13d96 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -12,9 +12,10 @@ import ( gomock "github.com/golang/mock/gomock" dns "github.com/netbirdio/netbird/dns" - reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" zones "github.com/netbirdio/netbird/management/internals/modules/zones" records "github.com/netbirdio/netbird/management/internals/modules/zones/records" types "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -150,6 +151,33 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID) } +// CleanupStaleProxies mocks base method. +func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupStaleProxies indicates an expected call of CleanupStaleProxies. +func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) +} + +// GetClusterSupportsCrowdSec mocks base method. +func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. +func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) +} // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -193,6 +221,21 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain) } +// CountEphemeralServicesByPeer mocks base method. +func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer. +func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID) +} + // CreateAccessLog mocks base method. func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { m.ctrl.T.Helper() @@ -293,7 +336,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C } // CreateService mocks base method. -func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { +func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateService", ctx, service) ret0, _ := ret[0].(error) @@ -559,6 +602,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID) } +// DeleteServiceTargets mocks base method. +func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteServiceTargets indicates an expected call of DeleteServiceTargets. +func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID) +} + // DeleteSetupKey mocks base method. func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { m.ctrl.T.Helper() @@ -573,6 +630,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID) } +// DeleteTarget mocks base method. +func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteTarget indicates an expected call of DeleteTarget. +func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID) +} + // DeleteTokenID2UserIDIndex mocks base method. func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error { m.ctrl.T.Helper() @@ -643,6 +714,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID) } +// EphemeralServiceExists mocks base method. +func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EphemeralServiceExists indicates an expected call of EphemeralServiceExists. +func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain) +} + // ExecuteInTransaction mocks base method. func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error { m.ctrl.T.Helper() @@ -1095,10 +1181,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i } // GetAccountServices mocks base method. -func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { +func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID) - ret0, _ := ret[0].([]*reverseproxy.Service) + ret0, _ := ret[0].([]*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1199,6 +1285,36 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx) } +// GetActiveProxyClusterAddresses mocks base method. +func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses. +func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) +} + +// GetActiveProxyClusters mocks base method. +func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) +} + // GetAllAccounts mocks base method. func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { m.ctrl.T.Helper() @@ -1258,6 +1374,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx) } +// GetClusterRequireSubdomain mocks base method. +func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain. +func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) +} + +// GetClusterSupportsCustomPorts mocks base method. +func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts. +func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) +} + // GetCustomDomain mocks base method. func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) { m.ctrl.T.Helper() @@ -1273,6 +1417,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID) } +// GetCustomDomainsCounts mocks base method. +func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts. +func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx) +} + // GetDNSRecordByID mocks base method. func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) { m.ctrl.T.Helper() @@ -1288,6 +1448,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID) } +// GetExpiredEphemeralServices mocks base method. +func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices. +func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit) +} + // GetGroupByID mocks base method. func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) { m.ctrl.T.Helper() @@ -1304,18 +1479,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou } // GetGroupByName mocks base method. -func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) { +func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID) + ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName) ret0, _ := ret[0].(*types2.Group) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupByName indicates an expected call of GetGroupByName. -func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName) } // GetGroupsByIDs mocks base method. @@ -1812,26 +1987,41 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID) } -// GetServiceByDomain mocks base method. -func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { +// GetRoutingPeerNetworks mocks base method. +func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) - ret0, _ := ret[0].(*reverseproxy.Service) + ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks. +func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID) +} + +// GetServiceByDomain mocks base method. +func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) + ret0, _ := ret[0].(*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } // GetServiceByDomain indicates an expected call of GetServiceByDomain. -func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain) } // GetServiceByID mocks base method. -func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { +func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID) - ret0, _ := ret[0].(*reverseproxy.Service) + ret0, _ := ret[0].(*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1843,10 +2033,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se } // GetServiceTargetByTargetID mocks base method. -func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) { +func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID) - ret0, _ := ret[0].(*reverseproxy.Target) + ret0, _ := ret[0].(*service.Target) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1858,10 +2048,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a } // GetServices mocks base method. -func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { +func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength) - ret0, _ := ret[0].([]*reverseproxy.Service) + ret0, _ := ret[0].([]*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1872,6 +2062,36 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength) } +// GetServicesByCluster mocks base method. +func (m *MockStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByCluster", ctx, lockStrength, proxyCluster) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByCluster indicates an expected call of GetServicesByCluster. +func (mr *MockStoreMockRecorder) GetServicesByCluster(ctx, lockStrength, proxyCluster interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByCluster", reflect.TypeOf((*MockStore)(nil).GetServicesByCluster), ctx, lockStrength, proxyCluster) +} + +// GetServicesByClusterAndPort mocks base method. +func (m *MockStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster, mode string, listenPort uint16) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByClusterAndPort", ctx, lockStrength, proxyCluster, mode, listenPort) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByClusterAndPort indicates an expected call of GetServicesByClusterAndPort. +func (mr *MockStoreMockRecorder) GetServicesByClusterAndPort(ctx, lockStrength, proxyCluster, mode, listenPort interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByClusterAndPort", reflect.TypeOf((*MockStore)(nil).GetServicesByClusterAndPort), ctx, lockStrength, proxyCluster, mode, listenPort) +} + // GetSetupKeyByID mocks base method. func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) { m.ctrl.T.Helper() @@ -1931,6 +2151,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId) } +// GetTargetsByServiceID mocks base method. +func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID) + ret0, _ := ret[0].([]*service.Target) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID. +func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID) +} + // GetTokenIDByHashedToken mocks base method. func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) { m.ctrl.T.Helper() @@ -2312,6 +2547,20 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID) } +// RenewEphemeralService mocks base method. +func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, serviceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenewEphemeralService indicates an expected call of RenewEphemeralService. +func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, serviceID) +} + // RevokeProxyAccessToken mocks base method. func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { m.ctrl.T.Helper() @@ -2536,6 +2785,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck) } +// SaveProxy mocks base method. +func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveProxy indicates an expected call of SaveProxy. +func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) +} + // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() @@ -2731,8 +2994,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups) } +// UpdateProxyHeartbeat mocks base method. +func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. +func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress) +} + // UpdateService mocks base method. -func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { +func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateService", ctx, service) ret0, _ := ret[0].(error) diff --git a/management/server/telemetry/account_aggregator.go b/management/server/telemetry/account_aggregator.go new file mode 100644 index 000000000..cd0863ed6 --- /dev/null +++ b/management/server/telemetry/account_aggregator.go @@ -0,0 +1,185 @@ +package telemetry + +import ( + "context" + "math" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95 +// without publishing individual account labels +type AccountDurationAggregator struct { + mu sync.RWMutex + accounts map[string]*accountHistogram + meterProvider *sdkmetric.MeterProvider + manualReader *sdkmetric.ManualReader + + FlushInterval time.Duration + MaxAge time.Duration + ctx context.Context +} + +type accountHistogram struct { + histogram metric.Int64Histogram + lastUpdate time.Time +} + +// NewAccountDurationAggregator creates aggregator using OTel histograms +func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator { + manualReader := sdkmetric.NewManualReader( + sdkmetric.WithTemporalitySelector(func(kind sdkmetric.InstrumentKind) metricdata.Temporality { + return metricdata.DeltaTemporality + }), + ) + + meterProvider := sdkmetric.NewMeterProvider( + sdkmetric.WithReader(manualReader), + ) + + return &AccountDurationAggregator{ + accounts: make(map[string]*accountHistogram), + meterProvider: meterProvider, + manualReader: manualReader, + FlushInterval: flushInterval, + MaxAge: maxAge, + ctx: ctx, + } +} + +// Record adds a duration for an account using OTel histogram +func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) { + a.mu.Lock() + defer a.mu.Unlock() + + accHist, exists := a.accounts[accountID] + if !exists { + meter := a.meterProvider.Meter("account-aggregator") + histogram, err := meter.Int64Histogram( + "sync_duration_per_account", + metric.WithUnit("milliseconds"), + ) + if err != nil { + return + } + + accHist = &accountHistogram{ + histogram: histogram, + } + a.accounts[accountID] = accHist + } + + accHist.histogram.Record(a.ctx, duration.Milliseconds(), + metric.WithAttributes(attribute.String("account_id", accountID))) + accHist.lastUpdate = time.Now() +} + +// FlushAndGetP95s extracts P95 from each account's histogram +func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 { + a.mu.Lock() + defer a.mu.Unlock() + + var rm metricdata.ResourceMetrics + err := a.manualReader.Collect(a.ctx, &rm) + if err != nil { + return nil + } + + now := time.Now() + p95s := make([]int64, 0, len(a.accounts)) + + for _, scopeMetrics := range rm.ScopeMetrics { + for _, metric := range scopeMetrics.Metrics { + histogramData, ok := metric.Data.(metricdata.Histogram[int64]) + if !ok { + continue + } + + for _, dataPoint := range histogramData.DataPoints { + a.processDataPoint(dataPoint, now, &p95s) + } + } + } + + a.cleanupStaleAccounts(now) + + return p95s +} + +// processDataPoint extracts P95 from a single histogram data point +func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) { + accountID := extractAccountID(dataPoint) + if accountID == "" { + return + } + + if p95 := calculateP95FromHistogram(dataPoint); p95 > 0 { + *p95s = append(*p95s, p95) + } +} + +// cleanupStaleAccounts removes accounts that haven't been updated recently +func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) { + for accountID := range a.accounts { + if a.isStaleAccount(accountID, now) { + delete(a.accounts, accountID) + } + } +} + +// extractAccountID retrieves the account_id from histogram data point attributes +func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string { + for _, attr := range dp.Attributes.ToSlice() { + if attr.Key == "account_id" { + return attr.Value.AsString() + } + } + return "" +} + +// isStaleAccount checks if an account hasn't been updated recently +func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool { + accHist, exists := a.accounts[accountID] + if !exists { + return false + } + return now.Sub(accHist.lastUpdate) > a.MaxAge +} + +// calculateP95FromHistogram computes P95 from OTel histogram data +func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 { + if dp.Count == 0 { + return 0 + } + + targetCount := uint64(math.Ceil(float64(dp.Count) * 0.95)) + if targetCount == 0 { + targetCount = 1 + } + var cumulativeCount uint64 + + for i, bucketCount := range dp.BucketCounts { + cumulativeCount += bucketCount + if cumulativeCount >= targetCount { + if i < len(dp.Bounds) { + return int64(dp.Bounds[i]) + } + if maxVal, defined := dp.Max.Value(); defined { + return maxVal + } + return dp.Sum / int64(dp.Count) + } + } + + return dp.Sum / int64(dp.Count) +} + +// Shutdown cleans up resources +func (a *AccountDurationAggregator) Shutdown() error { + return a.meterProvider.Shutdown(a.ctx) +} diff --git a/management/server/telemetry/account_aggregator_test.go b/management/server/telemetry/account_aggregator_test.go new file mode 100644 index 000000000..63b74b1db --- /dev/null +++ b/management/server/telemetry/account_aggregator_test.go @@ -0,0 +1,219 @@ +package telemetry + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) { + // Verify that with delta temporality, each flush window only reflects + // recordings since the last flush — not all-time data. + ctx := context.Background() + agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + // Window 1: Record 100 slow requests (500ms each) + for range 100 { + agg.Record("account-A", 500*time.Millisecond) + } + + p95sWindow1 := agg.FlushAndGetP95s() + require.Len(t, p95sWindow1, 1, "should have P95 for one account") + firstP95 := p95sWindow1[0] + assert.GreaterOrEqual(t, firstP95, int64(200), + "first window P95 should reflect the 500ms recordings") + + // Window 2: Record 100 FAST requests (10ms each) + for range 100 { + agg.Record("account-A", 10*time.Millisecond) + } + + p95sWindow2 := agg.FlushAndGetP95s() + require.Len(t, p95sWindow2, 1, "should have P95 for one account") + secondP95 := p95sWindow2[0] + + // With delta temporality the P95 should drop significantly because + // the first window's slow recordings are no longer included. + assert.Less(t, secondP95, firstP95, + "second window P95 should be lower than first — delta temporality "+ + "ensures each window only reflects recent recordings") +} + +func TestEqualWeightPerAccount(t *testing.T) { + // Verify that each account contributes exactly one P95 value, + // regardless of how many requests it made. + ctx := context.Background() + agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + // Account A: 10,000 requests at 500ms (noisy customer) + for range 10000 { + agg.Record("account-A", 500*time.Millisecond) + } + + // Accounts B, C, D: 10 requests each at 50ms (normal customers) + for _, id := range []string{"account-B", "account-C", "account-D"} { + for range 10 { + agg.Record(id, 50*time.Millisecond) + } + } + + p95s := agg.FlushAndGetP95s() + + // Should get exactly 4 P95 values — one per account + assert.Len(t, p95s, 4, "each account should contribute exactly one P95") +} + +func TestStaleAccountEviction(t *testing.T) { + ctx := context.Background() + // Use a very short MaxAge so we can test staleness + agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + agg.Record("account-A", 100*time.Millisecond) + agg.Record("account-B", 200*time.Millisecond) + + // Both accounts should appear + p95s := agg.FlushAndGetP95s() + assert.Len(t, p95s, 2, "both accounts should have P95 values") + + // Wait for account-A to become stale, then only update account-B + time.Sleep(60 * time.Millisecond) + agg.Record("account-B", 200*time.Millisecond) + + p95s = agg.FlushAndGetP95s() + assert.Len(t, p95s, 1, "both accounts should have P95 values") + + // account-A should have been evicted from the accounts map + agg.mu.RLock() + _, accountAExists := agg.accounts["account-A"] + _, accountBExists := agg.accounts["account-B"] + agg.mu.RUnlock() + + assert.False(t, accountAExists, "stale account-A should be evicted from map") + assert.True(t, accountBExists, "active account-B should remain in map") +} + +func TestStaleAccountEviction_DoesNotReappear(t *testing.T) { + // Verify that with delta temporality, an evicted stale account does not + // reappear in subsequent flushes. + ctx := context.Background() + agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + agg.Record("account-stale", 100*time.Millisecond) + + // Wait for it to become stale + time.Sleep(60 * time.Millisecond) + + // First flush: should detect staleness and evict + _ = agg.FlushAndGetP95s() + + agg.mu.RLock() + _, exists := agg.accounts["account-stale"] + agg.mu.RUnlock() + assert.False(t, exists, "account should be evicted after first flush") + + // Second flush: with delta temporality, the stale account should NOT reappear + p95sSecond := agg.FlushAndGetP95s() + assert.Empty(t, p95sSecond, + "evicted account should not reappear in subsequent flushes with delta temporality") +} + +func TestP95Calculation_SingleSample(t *testing.T) { + ctx := context.Background() + agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + agg.Record("account-A", 150*time.Millisecond) + + p95s := agg.FlushAndGetP95s() + require.Len(t, p95s, 1) + // With a single sample, P95 should be the bucket bound containing 150ms + assert.Greater(t, p95s[0], int64(0), "P95 of a single sample should be positive") +} + +func TestP95Calculation_AllSameValue(t *testing.T) { + ctx := context.Background() + agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + // All samples are 100ms — P95 should be the bucket bound containing 100ms + for range 100 { + agg.Record("account-A", 100*time.Millisecond) + } + + p95s := agg.FlushAndGetP95s() + require.Len(t, p95s, 1) + assert.Greater(t, p95s[0], int64(0)) +} + +func TestMultipleAccounts_IndependentP95s(t *testing.T) { + ctx := context.Background() + agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute) + defer func(agg *AccountDurationAggregator) { + err := agg.Shutdown() + if err != nil { + t.Errorf("failed to shutdown aggregator: %v", err) + } + }(agg) + + // Account A: all fast (10ms) + for range 100 { + agg.Record("account-fast", 10*time.Millisecond) + } + + // Account B: all slow (5000ms) + for range 100 { + agg.Record("account-slow", 5000*time.Millisecond) + } + + p95s := agg.FlushAndGetP95s() + require.Len(t, p95s, 2, "should have two P95 values") + + // Find min and max — they should differ significantly + minP95 := p95s[0] + maxP95 := p95s[1] + if minP95 > maxP95 { + minP95, maxP95 = maxP95, minP95 + } + + assert.Less(t, minP95, int64(1000), + "fast account P95 should be well under 1000ms") + assert.Greater(t, maxP95, int64(1000), + "slow account P95 should be well over 1000ms") +} diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index bd7fbc235..d3239c57a 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -13,18 +13,24 @@ const HighLatencyThreshold = time.Second * 7 // GRPCMetrics are gRPC server metrics type GRPCMetrics struct { - meter metric.Meter - syncRequestsCounter metric.Int64Counter - syncRequestsBlockedCounter metric.Int64Counter - loginRequestsCounter metric.Int64Counter - loginRequestsBlockedCounter metric.Int64Counter - loginRequestHighLatencyCounter metric.Int64Counter - getKeyRequestsCounter metric.Int64Counter - activeStreamsGauge metric.Int64ObservableGauge - syncRequestDuration metric.Int64Histogram - loginRequestDuration metric.Int64Histogram - channelQueueLength metric.Int64Histogram - ctx context.Context + meter metric.Meter + syncRequestsCounter metric.Int64Counter + syncRequestsBlockedCounter metric.Int64Counter + loginRequestsCounter metric.Int64Counter + loginRequestsBlockedCounter metric.Int64Counter + loginRequestHighLatencyCounter metric.Int64Counter + getKeyRequestsCounter metric.Int64Counter + activeStreamsGauge metric.Int64ObservableGauge + syncRequestDuration metric.Int64Histogram + syncRequestDurationP95ByAccount metric.Int64Histogram + loginRequestDuration metric.Int64Histogram + loginRequestDurationP95ByAccount metric.Int64Histogram + channelQueueLength metric.Int64Histogram + ctx context.Context + + // Per-account aggregation + syncDurationAggregator *AccountDurationAggregator + loginDurationAggregator *AccountDurationAggregator } // NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server @@ -93,6 +99,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"), + ) + if err != nil { + return nil, err + } + loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -101,6 +115,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"), + ) + if err != nil { + return nil, err + } + // We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time // Then we should be able to extract min, manx, mean and the percentiles. // TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100) @@ -113,20 +135,32 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } - return &GRPCMetrics{ - meter: meter, - syncRequestsCounter: syncRequestsCounter, - syncRequestsBlockedCounter: syncRequestsBlockedCounter, - loginRequestsCounter: loginRequestsCounter, - loginRequestsBlockedCounter: loginRequestsBlockedCounter, - loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, - getKeyRequestsCounter: getKeyRequestsCounter, - activeStreamsGauge: activeStreamsGauge, - syncRequestDuration: syncRequestDuration, - loginRequestDuration: loginRequestDuration, - channelQueueLength: channelQueue, - ctx: ctx, - }, err + syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute) + loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute) + + grpcMetrics := &GRPCMetrics{ + meter: meter, + syncRequestsCounter: syncRequestsCounter, + syncRequestsBlockedCounter: syncRequestsBlockedCounter, + loginRequestsCounter: loginRequestsCounter, + loginRequestsBlockedCounter: loginRequestsBlockedCounter, + loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, + getKeyRequestsCounter: getKeyRequestsCounter, + activeStreamsGauge: activeStreamsGauge, + syncRequestDuration: syncRequestDuration, + syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount, + loginRequestDuration: loginRequestDuration, + loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount, + channelQueueLength: channelQueue, + ctx: ctx, + syncDurationAggregator: syncDurationAggregator, + loginDurationAggregator: loginDurationAggregator, + } + + go grpcMetrics.startSyncP95Flusher() + go grpcMetrics.startLoginP95Flusher() + + return grpcMetrics, err } // CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API @@ -157,6 +191,9 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() { // CountLoginRequestDuration counts the duration of the login gRPC requests func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + + grpcMetrics.loginDurationAggregator.Record(accountID, duration) + if duration > HighLatencyThreshold { grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) } @@ -165,6 +202,44 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration // CountSyncRequestDuration counts the duration of the sync gRPC requests func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + + grpcMetrics.syncDurationAggregator.Record(accountID, duration) +} + +// startSyncP95Flusher periodically flushes per-account sync P95 values to the histogram +func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() { + ticker := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval) + defer ticker.Stop() + + for { + select { + case <-grpcMetrics.ctx.Done(): + return + case <-ticker.C: + p95s := grpcMetrics.syncDurationAggregator.FlushAndGetP95s() + for _, p95 := range p95s { + grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95) + } + } + } +} + +// startLoginP95Flusher periodically flushes per-account login P95 values to the histogram +func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() { + ticker := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval) + defer ticker.Stop() + + for { + select { + case <-grpcMetrics.ctx.Done(): + return + case <-ticker.C: + p95s := grpcMetrics.loginDurationAggregator.FlushAndGetP95s() + for _, p95 := range p95s { + grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95) + } + } + } } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index c50ed1e51..28e8457e2 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -183,7 +183,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { w := WrapResponseWriter(rw) + handlerDone := make(chan struct{}) + context.AfterFunc(ctx, func() { + select { + case <-handlerDone: + default: + log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)", + r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx)) + } + }) + h.ServeHTTP(w, r.WithContext(ctx)) + close(handlerDone) userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) if err == nil { diff --git a/management/server/types/account.go b/management/server/types/account.go index 3208cc89a..c448813db 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -18,7 +18,8 @@ import ( "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -100,7 +101,8 @@ type Account struct { NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` - Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"` + Services []*service.Service `gorm:"foreignKey:AccountID;references:id"` + Domains []*proxydomain.Domain `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` @@ -906,9 +908,14 @@ func (a *Account) Copy() *Account { networkResources = append(networkResources, resource.Copy()) } - services := []*reverseproxy.Service{} - for _, service := range a.Services { - services = append(services, service.Copy()) + services := []*service.Service{} + for _, svc := range a.Services { + services = append(services, svc.Copy()) + } + + domains := []*proxydomain.Domain{} + for _, domain := range a.Domains { + domains = append(domains, domain.Copy()) } return &Account{ @@ -936,6 +943,7 @@ func (a *Account) Copy() *Account { Onboarding: a.Onboarding, NetworkMapCache: a.NetworkMapCache, nmapInitOnce: a.nmapInitOnce, + Domains: domains, } } @@ -1605,12 +1613,12 @@ func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { networkResourceGroups := a.getNetworkResourceGroups(resourceId) for _, policy := range a.Policies { - if !policy.Enabled { + if policy == nil || !policy.Enabled { continue } for _, rule := range policy.Rules { - if !rule.Enabled { + if rule == nil || !rule.Enabled { continue } @@ -1812,18 +1820,21 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) { } a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster) } + } -func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { +func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { + proxyPeers := proxyPeersByCluster[service.ProxyCluster] for _, target := range service.Targets { if !target.Enabled { continue } - a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster]) + a.injectTargetProxyPolicies(ctx, service, target, proxyPeers) } + } -func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) { +func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) { port, ok := a.resolveTargetPort(ctx, target) if !ok { return @@ -1840,13 +1851,13 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers } } -func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) { +func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (uint16, bool) { if target.Port != 0 { return target.Port, true } switch target.Protocol { - case "https": + case "https", "tls": return 443, true case "http": return 80, true @@ -1856,17 +1867,23 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta } } -func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { - policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) +func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port uint16, path string) *Policy { + policyID := fmt.Sprintf("proxy-access-%s-%s-%s", svc.ID, proxyPeer.ID, path) + + protocol := PolicyRuleProtocolTCP + if svc.Mode == service.ModeUDP { + protocol = PolicyRuleProtocolUDP + } + return &Policy{ ID: policyID, - Name: fmt.Sprintf("Proxy Access to %s", service.Name), + Name: fmt.Sprintf("Proxy Access to %s", svc.Name), Enabled: true, Rules: []*PolicyRule{ { ID: policyID, PolicyID: policyID, - Name: fmt.Sprintf("Allow access to %s", service.Name), + Name: fmt.Sprintf("Allow access to %s", svc.Name), Enabled: true, SourceResource: Resource{ ID: proxyPeer.ID, @@ -1877,12 +1894,12 @@ func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *rever Type: ResourceType(target.TargetType), }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: protocol, Action: PolicyTrafficActionAccept, PortRanges: []RulePortRange{ { - Start: uint16(port), - End: uint16(port), + Start: port, + End: port, }, }, }, diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go new file mode 100644 index 000000000..bd4244546 --- /dev/null +++ b/management/server/types/account_components.go @@ -0,0 +1,576 @@ +package types + +import ( + "context" + "slices" + "time" + + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +func (a *Account) GetPeerNetworkMapFromComponents( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + metrics *telemetry.AccountManagerMetrics, + groupIDToUserIDs map[string][]string, +) *NetworkMap { + start := time.Now() + + components := a.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + accountZones, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + if components == nil { + return &NetworkMap{Network: a.Network.Copy()} + } + + nm := CalculateNetworkMapFromComponents(ctx, components) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from components, "+ + "peers: %d, offline peers: %d, routes: %d, firewall rules: %d, route firewall rules: %d", + a.Id, objectCount, len(nm.Peers), len(nm.OfflinePeers), len(nm.Routes), len(nm.FirewallRules), len(nm.RoutesFirewallRules)) + } + } + + return nm +} + +func (a *Account) GetPeerNetworkMapComponents( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + groupIDToUserIDs map[string][]string, +) *NetworkMapComponents { + + peer := a.Peers[peerID] + if peer == nil { + return nil + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return nil + } + + components := &NetworkMapComponents{ + PeerID: peerID, + Network: a.Network.Copy(), + NameServerGroups: make([]*nbdns.NameServerGroup, 0), + CustomZoneDomain: peersCustomZone.Domain, + ResourcePoliciesMap: make(map[string][]*Policy), + RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), + NetworkResources: make([]*resourceTypes.NetworkResource, 0), + PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)), + RouterPeers: make(map[string]*nbpeer.Peer), + } + + components.AccountSettings = &AccountSettingsInfo{ + PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: a.Settings.PeerLoginExpiration, + PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: a.Settings.PeerInactivityExpiration, + } + + components.DNSSettings = &a.DNSSettings + + relevantPeers, relevantGroups, relevantPolicies, relevantRoutes, sshReqs := a.getPeersGroupsPoliciesRoutes(ctx, peerID, peer.SSHEnabled, validatedPeersMap, &components.PostureFailedPeers) + + if len(sshReqs.neededGroupIDs) > 0 { + components.GroupIDToUserIDs = filterGroupIDToUserIDs(groupIDToUserIDs, sshReqs.neededGroupIDs) + } + if sshReqs.needAllowedUserIDs { + components.AllowedUserIDs = a.getAllowedUserIDs() + } + + components.Peers = relevantPeers + components.Groups = relevantGroups + components.Policies = relevantPolicies + components.Routes = relevantRoutes + components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers) + + peerGroups := a.GetPeerGroups(peerID) + components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups) + + for _, nsGroup := range a.NameServerGroups { + if nsGroup.Enabled { + for _, gID := range nsGroup.Groups { + if _, found := relevantGroups[gID]; found { + components.NameServerGroups = append(components.NameServerGroups, nsGroup) + break + } + } + } + } + + for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + + policies, exists := resourcePolicies[resource.ID] + if !exists { + continue + } + + addSourcePeers := false + + networkRoutingPeers, routerExists := routers[resource.NetworkID] + if routerExists { + if _, ok := networkRoutingPeers[peerID]; ok { + addSourcePeers = true + } + } + + for _, policy := range policies { + if addSourcePeers { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } + for _, pID := range a.getPostureValidPeersSaveFailed(peers, policy.SourcePostureChecks, validatedPeersMap, &components.PostureFailedPeers) { + if _, exists := components.Peers[pID]; !exists { + components.Peers[pID] = a.GetPeer(pID) + } + } + } else { + peerInSources := false + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peerInSources = policy.Rules[0].SourceResource.ID == peerID + } else { + for _, groupID := range policy.SourceGroups() { + if group := a.GetGroup(groupID); group != nil && slices.Contains(group.Peers, peerID) { + peerInSources = true + break + } + } + } + if !peerInSources { + continue + } + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, policy.SourcePostureChecks, peerID) + if !isValid && len(pname) > 0 { + if _, ok := components.PostureFailedPeers[pname]; !ok { + components.PostureFailedPeers[pname] = make(map[string]struct{}) + } + components.PostureFailedPeers[pname][peer.ID] = struct{}{} + continue + } + addSourcePeers = true + } + + for _, rule := range policy.Rules { + for _, srcGroupID := range rule.Sources { + if g := a.Groups[srcGroupID]; g != nil { + if _, exists := components.Groups[srcGroupID]; !exists { + components.Groups[srcGroupID] = g + } + } + } + for _, dstGroupID := range rule.Destinations { + if g := a.Groups[dstGroupID]; g != nil { + if _, exists := components.Groups[dstGroupID]; !exists { + components.Groups[dstGroupID] = g + } + } + } + } + components.ResourcePoliciesMap[resource.ID] = policies + } + + components.RoutersMap[resource.NetworkID] = networkRoutingPeers + for peerIDKey := range networkRoutingPeers { + if p := a.Peers[peerIDKey]; p != nil { + if _, exists := components.RouterPeers[peerIDKey]; !exists { + components.RouterPeers[peerIDKey] = p + } + if _, exists := components.Peers[peerIDKey]; !exists { + if _, validated := validatedPeersMap[peerIDKey]; validated { + components.Peers[peerIDKey] = p + } + } + } + } + + if addSourcePeers { + components.NetworkResources = append(components.NetworkResources, resource) + } + } + + filterGroupPeers(&components.Groups, components.Peers) + filterPostureFailedPeers(&components.PostureFailedPeers, components.Policies, components.ResourcePoliciesMap, components.Peers) + + return components +} + +type sshRequirements struct { + neededGroupIDs map[string]struct{} + needAllowedUserIDs bool +} + +func (a *Account) getPeersGroupsPoliciesRoutes( + ctx context.Context, + peerID string, + peerSSHEnabled bool, + validatedPeersMap map[string]struct{}, + postureFailedPeers *map[string]map[string]struct{}, +) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route, sshRequirements) { + relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4) + relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4) + relevantPolicies := make([]*Policy, 0, len(a.Policies)) + relevantRoutes := make([]*route.Route, 0, len(a.Routes)) + sshReqs := sshRequirements{neededGroupIDs: make(map[string]struct{})} + + relevantPeerIDs[peerID] = a.GetPeer(peerID) + + for groupID, group := range a.Groups { + if slices.Contains(group.Peers, peerID) { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + } + + routeAccessControlGroups := make(map[string]struct{}) + for _, r := range a.Routes { + for _, groupID := range r.Groups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + for _, groupID := range r.PeerGroups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + if r.Enabled { + for _, groupID := range r.AccessControlGroups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + routeAccessControlGroups[groupID] = struct{}{} + } + } + relevantRoutes = append(relevantRoutes, r) + } + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + policyRelevant := false + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + if len(routeAccessControlGroups) > 0 { + for _, destGroupID := range rule.Destinations { + if _, needed := routeAccessControlGroups[destGroupID]; needed { + policyRelevant = true + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID) + } + break + } + } + } + + var sourcePeers, destinationPeers []string + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers = []string{rule.SourceResource.ID} + if rule.SourceResource.ID == peerID { + peerInSources = true + } + } else { + sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap, postureFailedPeers) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers = []string{rule.DestinationResource.ID} + if rule.DestinationResource.ID == peerID { + peerInDestinations = true + } + } else { + destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap, postureFailedPeers) + } + + if peerInSources { + policyRelevant = true + for _, pid := range destinationPeers { + relevantPeerIDs[pid] = a.GetPeer(pid) + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID) + } + } + + if peerInDestinations { + policyRelevant = true + for _, pid := range sourcePeers { + relevantPeerIDs[pid] = a.GetPeer(pid) + } + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) + } + + if rule.Protocol == PolicyRuleProtocolNetbirdSSH { + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID := range rule.AuthorizedGroups { + sshReqs.neededGroupIDs[groupID] = struct{}{} + } + case rule.AuthorizedUser != "": + default: + sshReqs.needAllowedUserIDs = true + } + } else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled { + sshReqs.needAllowedUserIDs = true + } + } + } + if policyRelevant { + relevantPolicies = append(relevantPolicies, policy) + } + } + + return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes, sshReqs +} + +func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, + validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) { + peerInGroups := false + filteredPeerIDs := make([]string, 0, len(groups)) + seenPeerIds := make(map[string]struct{}, len(groups)) + + for _, gid := range groups { + group := a.GetGroup(gid) + if group == nil { + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + filteredPeerIDs = make([]string, 0, len(group.Peers)) + peerInGroups = false + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok || peer == nil { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid && len(pname) > 0 { + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peer.ID] = struct{}{} + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + return filteredPeerIDs, peerInGroups + } + + for _, pid := range group.Peers { + if _, seen := seenPeerIds[pid]; seen { + continue + } + seenPeerIds[pid] = struct{}{} + peer, ok := a.Peers[pid] + if !ok || peer == nil { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid && len(pname) > 0 { + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peer.ID] = struct{}{} + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + } + + return filteredPeerIDs, peerInGroups +} + +func (a *Account) validatePostureChecksOnPeerGetFailed(ctx context.Context, sourcePostureChecksID []string, peerID string) (bool, string) { + peer, ok := a.Peers[peerID] + if !ok || peer == nil { + return false, "" + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := a.GetPostureChecks(postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.GetChecks() { + isValid, _ := check.Check(ctx, *peer) + if !isValid { + return false, postureChecksID + } + } + } + return true, "" +} + +func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) []string { + var dest []string + for _, peerID := range inputPeers { + if _, validated := validatedPeersMap[peerID]; !validated { + continue + } + valid, pname := a.validatePostureChecksOnPeerGetFailed(context.Background(), postureChecksIDs, peerID) + if valid { + dest = append(dest, peerID) + continue + } + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peerID] = struct{}{} + } + return dest +} + +func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) { + for groupID, groupInfo := range *groups { + filteredPeers := make([]string, 0, len(groupInfo.Peers)) + for _, pid := range groupInfo.Peers { + if _, exists := peers[pid]; exists { + filteredPeers = append(filteredPeers, pid) + } + } + + if len(filteredPeers) == 0 { + delete(*groups, groupID) + } else if len(filteredPeers) != len(groupInfo.Peers) { + ng := groupInfo.Copy() + ng.Peers = filteredPeers + (*groups)[groupID] = ng + } + } +} + +func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{}, policies []*Policy, resourcePoliciesMap map[string][]*Policy, peers map[string]*nbpeer.Peer) { + if len(*postureFailedPeers) == 0 { + return + } + + referencedPostureChecks := make(map[string]struct{}) + for _, policy := range policies { + for _, checkID := range policy.SourcePostureChecks { + referencedPostureChecks[checkID] = struct{}{} + } + } + for _, resPolicies := range resourcePoliciesMap { + for _, policy := range resPolicies { + for _, checkID := range policy.SourcePostureChecks { + referencedPostureChecks[checkID] = struct{}{} + } + } + } + + for checkID, failedPeers := range *postureFailedPeers { + if _, referenced := referencedPostureChecks[checkID]; !referenced { + delete(*postureFailedPeers, checkID) + continue + } + for peerID := range failedPeers { + if _, exists := peers[peerID]; !exists { + delete(failedPeers, peerID) + } + } + if len(failedPeers) == 0 { + delete(*postureFailedPeers, checkID) + } + } +} + +func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord { + if len(records) == 0 || len(peers) == 0 { + return nil + } + + peerIPs := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + if peer != nil { + peerIPs[peer.IP.String()] = struct{}{} + } + } + + filteredRecords := make([]nbdns.SimpleRecord, 0, len(records)) + for _, record := range records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} + +func filterGroupIDToUserIDs(fullMap map[string][]string, neededGroupIDs map[string]struct{}) map[string][]string { + if len(neededGroupIDs) == 0 { + return nil + } + + filtered := make(map[string][]string, len(neededGroupIDs)) + for groupID := range neededGroupIDs { + if users, ok := fullMap[groupID]; ok { + filtered[groupID] = users + } + } + return filtered +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index af2896216..00ba29b7f 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -84,6 +84,12 @@ func setupTestAccount() *Account { }, }, Groups: map[string]*Group{ + "groupAll": { + ID: "groupAll", + Name: "All", + Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"}, + Issued: GroupIssuedAPI, + }, "group1": { ID: "group1", Peers: []string{"peer11", "peer12"}, diff --git a/management/server/types/network.go b/management/server/types/network.go index d3708d80a..0d13de10f 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -152,6 +152,8 @@ func (n *Network) CurrentSerial() uint64 { } func (n *Network) Copy() *Network { + n.Mu.Lock() + defer n.Mu.Unlock() return &Network{ Identifier: n.Identifier, Net: n.Net, diff --git a/management/server/types/networkmap_benchmark_test.go b/management/server/types/networkmap_benchmark_test.go new file mode 100644 index 000000000..38272e7b0 --- /dev/null +++ b/management/server/types/networkmap_benchmark_test.go @@ -0,0 +1,217 @@ +package types_test + +import ( + "context" + "fmt" + "os" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/types" +) + +type benchmarkScale struct { + name string + peers int + groups int +} + +var defaultScales = []benchmarkScale{ + {"100peers_5groups", 100, 5}, + {"500peers_20groups", 500, 20}, + {"1000peers_50groups", 1000, 50}, + {"5000peers_100groups", 5000, 100}, + {"10000peers_200groups", 10000, 200}, + {"20000peers_200groups", 20000, 200}, + {"30000peers_300groups", 30000, 300}, +} + +func skipCIBenchmark(b *testing.B) { + if os.Getenv("CI") == "true" { + b.Skip("Skipping benchmark in CI") + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Single Peer Network Map Generation +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer. +func BenchmarkNetworkMapGeneration_Components(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// All Peers (UpdateAccountPeers hot path) +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers. +func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) { + skipCIBenchmark(b) + scales := []benchmarkScale{ + {"100peers_5groups", 100, 5}, + {"500peers_20groups", 500, 20}, + {"1000peers_50groups", 1000, 50}, + {"5000peers_100groups", 5000, 100}, + } + + for _, scale := range scales { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + + peerIDs := make([]string, 0, len(account.Peers)) + for peerID := range account.Peers { + peerIDs = append(peerIDs, peerID) + } + + b.Run("components/"+scale.name, func(b *testing.B) { + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Sub-operations +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction. +func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components. +func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = types.CalculateNetworkMapFromComponents(ctx, components) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs. +func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetResourcePoliciesMap() + } + }) + b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetResourceRoutersMap() + } + }) + b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetActiveGroupUsers() + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Scaling Analysis +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance. +func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) { + skipCIBenchmark(b) + groupCounts := []int{1, 5, 20, 50, 100, 200, 500} + for _, numGroups := range groupCounts { + b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) { + account, validatedPeers := scalableTestAccount(1000, numGroups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance. +func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) { + skipCIBenchmark(b) + peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000} + for _, numPeers := range peerCounts { + numGroups := numPeers / 20 + if numGroups < 1 { + numGroups = 1 + } + b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) { + account, validatedPeers := scalableTestAccount(numPeers, numGroups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go new file mode 100644 index 000000000..c5844cca0 --- /dev/null +++ b/management/server/types/networkmap_comparison_test.go @@ -0,0 +1,592 @@ +package types + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" +) + +func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + if components == nil { + t.Fatal("GetPeerNetworkMapComponents returned nil") + } + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + + if newNetworkMap == nil { + t.Fatal("CalculateNetworkMapFromComponents returned nil") + } + + compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) +} + +func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") + + normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + componentsJSON, err := json.MarshalIndent(components, "", " ") + require.NoError(t, err, "error marshaling components to JSON") + + legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") + + newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") + + goldenDir := filepath.Join("testdata", "comparison") + err = os.MkdirAll(goldenDir, 0755) + require.NoError(t, err) + + legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") + err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) + require.NoError(t, err, "error writing legacy golden file") + + newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") + err = os.WriteFile(newGoldenPath, newJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + componentsPath := filepath.Join(goldenDir, "components.json") + err = os.WriteFile(componentsPath, componentsJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + require.JSONEq(t, string(legacyJSON), string(newJSON), + "NetworkMaps from legacy and components approaches do not match.\n"+ + "Legacy JSON saved to: %s\n"+ + "Components JSON saved to: %s", + legacyGoldenPath, newGoldenPath) + + t.Logf("✅ NetworkMaps are identical") + t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) + t.Logf(" Components NetworkMap: %s", newGoldenPath) +} + +func normalizeAndSortNetworkMap(nm *NetworkMap) { + if nm == nil { + return + } + + sort.Slice(nm.Peers, func(i, j int) bool { + return nm.Peers[i].ID < nm.Peers[j].ID + }) + + sort.Slice(nm.OfflinePeers, func(i, j int) bool { + return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID + }) + + sort.Slice(nm.Routes, func(i, j int) bool { + return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) + }) + + sort.Slice(nm.FirewallRules, func(i, j int) bool { + if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { + return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP + } + if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { + return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction + } + if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol { + return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol + } + if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port { + return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port + } + return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID + }) + + for i := range nm.RoutesFirewallRules { + sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) + } + + sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { + if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { + return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination + } + + minLen := len(nm.RoutesFirewallRules[i].SourceRanges) + if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { + minLen = len(nm.RoutesFirewallRules[j].SourceRanges) + } + for k := 0; k < minLen; k++ { + if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { + return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] + } + } + if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { + return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) + } + + if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) { + return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) + } + + if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID { + return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID + } + + if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port { + return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port + } + + return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol + }) + + if nm.DNSConfig.CustomZones != nil { + for i := range nm.DNSConfig.CustomZones { + sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { + return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name + }) + } + } + + if len(nm.DNSConfig.NameServerGroups) != 0 { + sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool { + return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name + }) + } +} + +func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) { + t.Helper() + + if legacy.Network.Serial != current.Network.Serial { + t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial) + } + + if len(legacy.Peers) != len(current.Peers) { + t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers)) + } + + legacyPeerIDs := make(map[string]bool) + for _, p := range legacy.Peers { + legacyPeerIDs[p.ID] = true + } + + for _, p := range current.Peers { + if !legacyPeerIDs[p.ID] { + t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID) + } + } + + if len(legacy.OfflinePeers) != len(current.OfflinePeers) { + t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers)) + } + + if len(legacy.FirewallRules) != len(current.FirewallRules) { + t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules)) + } + + if len(legacy.Routes) != len(current.Routes) { + t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes)) + } + + if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) { + t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules)) + } + + if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable { + t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable) + } +} + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" + expiredPeerID = "peer-98" + offlinePeerID = "peer-99" + routingPeerID = "peer-95" + testAccountID = "account-comparison-test" +) + +func createTestAccount() *Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + } + + groups := map[string]*Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, + Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &Network{ + Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} + +func BenchmarkLegacyNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + } +} + +func BenchmarkComponentsNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} + +func BenchmarkComponentsCreation(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + } +} + +func BenchmarkCalculationFromComponents(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go new file mode 100644 index 000000000..23d84a994 --- /dev/null +++ b/management/server/types/networkmap_components.go @@ -0,0 +1,901 @@ +package types + +import ( + "context" + "maps" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/netbirdio/netbird/client/ssh/auth" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED" + +type NetworkMapComponents struct { + PeerID string + + Network *Network + AccountSettings *AccountSettingsInfo + DNSSettings *DNSSettings + CustomZoneDomain string + + Peers map[string]*nbpeer.Peer + Groups map[string]*Group + Policies []*Policy + Routes []*route.Route + NameServerGroups []*nbdns.NameServerGroup + AllDNSRecords []nbdns.SimpleRecord + AccountZones []nbdns.CustomZone + ResourcePoliciesMap map[string][]*Policy + RoutersMap map[string]map[string]*routerTypes.NetworkRouter + NetworkResources []*resourceTypes.NetworkResource + + GroupIDToUserIDs map[string][]string + AllowedUserIDs map[string]struct{} + PostureFailedPeers map[string]map[string]struct{} + + RouterPeers map[string]*nbpeer.Peer +} + +type AccountSettingsInfo struct { + PeerLoginExpirationEnabled bool + PeerLoginExpiration time.Duration + PeerInactivityExpirationEnabled bool + PeerInactivityExpiration time.Duration +} + +func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer { + return c.Peers[peerID] +} + +func (c *NetworkMapComponents) GetRouterPeerInfo(peerID string) *nbpeer.Peer { + return c.RouterPeers[peerID] +} + +func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group { + return c.Groups[groupID] +} + +func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool { + group := c.GetGroupInfo(groupID) + if group == nil { + return false + } + + return slices.Contains(group.Peers, peerID) +} + +func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} { + groups := make(map[string]struct{}) + for groupID, group := range c.Groups { + if slices.Contains(group.Peers, peerID) { + groups[groupID] = struct{}{} + } + } + return groups +} + +func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool { + _, exists := c.Peers[peerID] + if !exists { + return false + } + if len(postureCheckIDs) == 0 { + return true + } + for _, checkID := range postureCheckIDs { + if failedPeers, exists := c.PostureFailedPeers[checkID]; exists { + if _, failed := failedPeers[peerID]; failed { + return false + } + } + } + return true +} + +func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap { + return components.Calculate(ctx) +} + +func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { + targetPeerID := c.PeerID + + peerGroups := c.GetPeerGroups(targetPeerID) + + aclPeers, firewallRules, authorizedUsers, sshEnabled := c.getPeerConnectionResources(targetPeerID) + + peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers) + + routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups) + routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID) + + isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes) + } + + peersToConnectIncludingRouters := c.addNetworksRoutingPeers( + networkResourcesRoutes, + targetPeerID, + peersToConnect, + expiredPeers, + isRouter, + sourcePeers, + ) + + dnsManagementStatus := c.getPeerDNSManagementStatusFromGroups(peerGroups) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var customZones []nbdns.CustomZone + + if c.CustomZoneDomain != "" && len(c.AllDNSRecords) > 0 { + customZones = append(customZones, nbdns.CustomZone{ + Domain: c.CustomZoneDomain, + Records: c.AllDNSRecords, + }) + } + + customZones = append(customZones, c.AccountZones...) + + dnsUpdate.CustomZones = customZones + dnsUpdate.NameServerGroups = c.getPeerNSGroupsFromGroups(targetPeerID, peerGroups) + } + + return &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: c.Network.Copy(), + Routes: append(networkResourcesRoutes, routesUpdate...), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...), + AuthorizedUsers: authorizedUsers, + EnableSSH: sshEnabled, + } +} + +func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { + targetPeer := c.GetPeerInfo(targetPeerID) + if targetPeer == nil { + return nil, nil, nil, false + } + + generateResources, getAccumulatedResources := c.connResourcesGenerator(targetPeer) + authorizedUsers := make(map[string]map[string]struct{}) + sshEnabled := false + + for _, policy := range c.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = c.getPeerFromResource(rule.SourceResource, targetPeerID) + } else { + sourcePeers, peerInSources = c.getAllPeersFromGroups(rule.Sources, targetPeerID, policy.SourcePostureChecks) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = c.getPeerFromResource(rule.DestinationResource, targetPeerID) + } else { + destinationPeers, peerInDestinations = c.getAllPeersFromGroups(rule.Destinations, targetPeerID, nil) + } + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := c.GroupIDToUserIDs[groupID] + if !ok { + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs() + } + } + } + + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (c *NetworkMapComponents) getAllowedUserIDs() map[string]struct{} { + if c.AllowedUserIDs != nil { + result := make(map[string]struct{}, len(c.AllowedUserIDs)) + maps.Copy(result, c.AllowedUserIDs) + return result + } + return make(map[string]struct{}) +} + +func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + + protocolStr := string(protocol) + actionStr := string(rule.Action) + dirStr := strconv.Itoa(direction) + portsJoined := strings.Join(rule.Ports, ",") + + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + peerIP := net.IP(peer.IP).String() + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: peerIP, + Direction: direction, + Action: actionStr, + Protocol: protocolStr, + } + + ruleID := rule.ID + peerIP + dirStr + + protocolStr + actionStr + portsJoined + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + rules = append(rules, &fr) + continue + } + + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) + } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +func (c *NetworkMapComponents) getAllPeersFromGroups(groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { + peerInGroups := false + uniquePeerIDs := c.getUniquePeerIDsFromGroupsIDs(groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + + for _, p := range uniquePeerIDs { + peerInfo := c.GetPeerInfo(p) + if peerInfo == nil { + continue + } + + if _, ok := c.Peers[p]; !ok { + continue + } + + if !c.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) { + continue + } + + if p == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peerInfo) + } + + return filteredPeers, peerInGroups +} + +func (c *NetworkMapComponents) getUniquePeerIDsFromGroupsIDs(groups []string) []string { + peerIDs := make(map[string]struct{}, len(groups)) + for _, groupID := range groups { + group := c.GetGroupInfo(groupID) + if group == nil { + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } + } + + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) + } + + return ids +} + +func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + if resource.ID == peerID { + return []*nbpeer.Peer{}, true + } + + peerInfo := c.GetPeerInfo(resource.ID) + if peerInfo == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peerInfo}, false +} + +func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { + peersToConnect := make([]*nbpeer.Peer, 0, len(aclPeers)) + var expiredPeers []*nbpeer.Peer + + for _, p := range aclPeers { + expired, _ := p.LoginExpired(c.AccountSettings.PeerLoginExpiration) + if c.AccountSettings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + return peersToConnect, expiredPeers +} + +func (c *NetworkMapComponents) getPeerDNSManagementStatusFromGroups(peerGroups map[string]struct{}) bool { + for _, groupID := range c.DNSSettings.DisabledManagementGroups { + if _, found := peerGroups[groupID]; found { + return false + } + } + return true +} + +func (c *NetworkMapComponents) getPeerNSGroupsFromGroups(peerID string, groupList map[string]struct{}) []*nbdns.NameServerGroup { + var peerNSGroups []*nbdns.NameServerGroup + + targetPeerInfo := c.GetPeerInfo(peerID) + if targetPeerInfo == nil { + return peerNSGroups + } + + peerIPStr := targetPeerInfo.IP.String() + + for _, nsGroup := range c.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + if _, found := groupList[gID]; found { + if !c.peerIsNameserver(peerIPStr, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + } + break + } + } + } + + return peerNSGroups +} + +func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peerIPStr == ns.IP.String() { + return true + } + } + return false +} + +func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { + routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + for _, peer := range aclPeers { + activeRoutes, _ := c.getRoutingPeerRoutes(peer.ID) + groupFilteredRoutes := c.filterRoutesByGroups(activeRoutes, peerGroups) + filteredRoutes := c.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peerInfo := c.GetPeerInfo(peerID) + if peerInfo == nil { + peerInfo = c.GetRouterPeerInfo(peerID) + } + if peerInfo == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + r.Peer = peerInfo.Key + + if r.Enabled { + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + for _, r := range c.Routes { + for _, groupID := range r.PeerGroups { + group := c.GetGroupInfo(groupID) + if group == nil { + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) + takeRoute(newPeerRoute) + break + } + } + if r.Peer == peerID { + takeRoute(r.Copy()) + } + } + + return enabledRoutes, disabledRoutes +} + + +func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := c.getRoutingPeerRoutes(peerID) + for _, r := range enabledRoutes { + if len(r.AccessControlGroups) == 0 { + defaultPermit := c.getDefaultPermit(r) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := c.getDistributionGroupsPeers(r) + + for _, accessGroup := range r.AccessControlGroups { + policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup}) + rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if r.Network.Addr().Is6() { + sources = []string{"::/0"} + } + + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: r.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: r.Domains, + IsDynamic: r.IsDynamic(), + RouteID: r.ID, + } + + rules = append(rules, &rule) + + if r.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +func (c *NetworkMapComponents) getDistributionGroupsPeers(r *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range r.Groups { + group := c.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + for _, policy := range c.Policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies = append(routePolicies, policy) + } + } + } + } + + return routePolicies +} + +func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (c *NetworkMapComponents) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := c.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := c.Peers[pID] + if distPeer && valid && c.ValidatePostureChecksOnPeer(pID, postureChecks) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := c.Peers[rule.SourceResource.ID] + if distPeer && valid && c.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peerInfo := c.GetPeerInfo(pID) + if peerInfo == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peerInfo) + } + return distributionGroupPeers +} + +func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (bool, []*route.Route, map[string]struct{}) { + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + for _, resource := range c.NetworkResources { + if !resource.Enabled { + continue + } + + var addSourcePeers bool + + networkRoutingPeers, exists := c.RoutersMap[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, c.getNetworkResourcesRoutes(resource, peerID, router)...) + } + } + + addedResourceRoute := false + for _, policy := range c.ResourcePoliciesMap[resource.ID] { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups()) + } + if addSourcePeers { + for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + } else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...) + } + addedResourceRoute = true + } + if addedResourceRoute { + break + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route { + resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID] + + var routes []*route.Route + if len(resourceAppliedPolicies) > 0 { + peerInfo := c.GetPeerInfo(peerID) + if peerInfo != nil { + routes = append(routes, c.networkResourceToRoute(resource, peerInfo, router)) + } + } + + return routes +} + +func (c *NetworkMapComponents) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(resource.ID + ":" + peer.ID), + AccountID: resource.AccountID, + Peer: peer.Key, + PeerID: peer.ID, + Metric: router.Metric, + Masquerade: router.Masquerade, + Enabled: resource.Enabled, + KeepRoute: true, + NetID: route.NetID(resource.Name), + Description: resource.Description, + } + + if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet { + r.Network = resource.Prefix + + r.NetworkType = route.IPv4Network + if resource.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if resource.Type == resourceTypes.Domain { + domainList, err := domain.FromStringList([]string{resource.Domain}) + if err == nil { + r.Domains = domainList + r.NetworkType = route.DomainNetwork + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + } + + return r +} + +func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if c.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) { + dest = append(dest, peerID) + } + } + return dest +} + +func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + peerInfo := c.GetPeerInfo(peerID) + if peerInfo == nil { + return routesFirewallRules + } + + for _, r := range routes { + if r.Peer != peerInfo.Key { + continue + } + + resourceID := string(r.GetResourceID()) + resourcePolicies := c.ResourcePoliciesMap[resourceID] + distributionPeers := c.getPoliciesSourcePeers(resourcePolicies) + + rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + } + + return routesFirewallRules +} + +func (c *NetworkMapComponents) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := c.GetGroupInfo(sourceGroup) + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } + } + } + + return sourcePeers +} + +func (c *NetworkMapComponents) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peerID string, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peerID) + delete(networkRoutesPeers, peerID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + peerInfo := c.GetPeerInfo(p) + if peerInfo == nil { + peerInfo = c.GetRouterPeerInfo(p) + } + if peerInfo != nil { + peersToConnect = append(peersToConnect, peerInfo) + } + } + + return peersToConnect +} diff --git a/management/server/types/networkmap_components_compact.go b/management/server/types/networkmap_components_compact.go new file mode 100644 index 000000000..b60f8bdb1 --- /dev/null +++ b/management/server/types/networkmap_components_compact.go @@ -0,0 +1,230 @@ +package types + +import ( + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" +) + +type GroupCompact struct { + Name string + PeerIndexes []int +} + +type NetworkMapComponentsCompact struct { + PeerID string + + Network *Network + AccountSettings *AccountSettingsInfo + DNSSettings *DNSSettings + CustomZoneDomain string + + AllPeers []*nbpeer.Peer + PeerIndexes []int + RouterPeerIndexes []int + + Groups map[string]*GroupCompact + AllPolicies []*Policy + PolicyIndexes []int + ResourcePoliciesMap map[string][]int + Routes []*route.Route + NameServerGroups []*nbdns.NameServerGroup + AllDNSRecords []nbdns.SimpleRecord + AccountZones []nbdns.CustomZone + + RoutersMap map[string]map[string]*routerTypes.NetworkRouter + NetworkResources []*resourceTypes.NetworkResource + + GroupIDToUserIDs map[string][]string + AllowedUserIDs map[string]struct{} + PostureFailedPeers map[string]map[string]struct{} +} + +func (c *NetworkMapComponents) ToCompact() *NetworkMapComponentsCompact { + peerToIndex := make(map[string]int) + var allPeers []*nbpeer.Peer + + for id, peer := range c.Peers { + if _, exists := peerToIndex[id]; !exists { + peerToIndex[id] = len(allPeers) + allPeers = append(allPeers, peer) + } + } + + for id, peer := range c.RouterPeers { + if _, exists := peerToIndex[id]; !exists { + peerToIndex[id] = len(allPeers) + allPeers = append(allPeers, peer) + } + } + + peerIndexes := make([]int, 0, len(c.Peers)) + for id := range c.Peers { + peerIndexes = append(peerIndexes, peerToIndex[id]) + } + + routerPeerIndexes := make([]int, 0, len(c.RouterPeers)) + for id := range c.RouterPeers { + routerPeerIndexes = append(routerPeerIndexes, peerToIndex[id]) + } + + groups := make(map[string]*GroupCompact, len(c.Groups)) + for id, group := range c.Groups { + peerIdxs := make([]int, 0, len(group.Peers)) + for _, peerID := range group.Peers { + if idx, ok := peerToIndex[peerID]; ok { + peerIdxs = append(peerIdxs, idx) + } + } + groups[id] = &GroupCompact{ + Name: group.Name, + PeerIndexes: peerIdxs, + } + } + + policyToIndex := make(map[*Policy]int) + var allPolicies []*Policy + + for _, policy := range c.Policies { + if _, exists := policyToIndex[policy]; !exists { + policyToIndex[policy] = len(allPolicies) + allPolicies = append(allPolicies, policy) + } + } + + for _, policies := range c.ResourcePoliciesMap { + for _, policy := range policies { + if _, exists := policyToIndex[policy]; !exists { + policyToIndex[policy] = len(allPolicies) + allPolicies = append(allPolicies, policy) + } + } + } + + policyIndexes := make([]int, len(c.Policies)) + for i, policy := range c.Policies { + policyIndexes[i] = policyToIndex[policy] + } + + var resourcePoliciesMap map[string][]int + if len(c.ResourcePoliciesMap) > 0 { + resourcePoliciesMap = make(map[string][]int, len(c.ResourcePoliciesMap)) + for resID, policies := range c.ResourcePoliciesMap { + indexes := make([]int, len(policies)) + for i, policy := range policies { + indexes[i] = policyToIndex[policy] + } + resourcePoliciesMap[resID] = indexes + } + } + + return &NetworkMapComponentsCompact{ + PeerID: c.PeerID, + Network: c.Network, + AccountSettings: c.AccountSettings, + DNSSettings: c.DNSSettings, + CustomZoneDomain: c.CustomZoneDomain, + + AllPeers: allPeers, + PeerIndexes: peerIndexes, + RouterPeerIndexes: routerPeerIndexes, + + Groups: groups, + AllPolicies: allPolicies, + PolicyIndexes: policyIndexes, + ResourcePoliciesMap: resourcePoliciesMap, + Routes: c.Routes, + NameServerGroups: c.NameServerGroups, + AllDNSRecords: c.AllDNSRecords, + AccountZones: c.AccountZones, + + RoutersMap: c.RoutersMap, + NetworkResources: c.NetworkResources, + + GroupIDToUserIDs: c.GroupIDToUserIDs, + AllowedUserIDs: c.AllowedUserIDs, + PostureFailedPeers: c.PostureFailedPeers, + } +} + +func (c *NetworkMapComponentsCompact) ToFull() *NetworkMapComponents { + peers := make(map[string]*nbpeer.Peer, len(c.PeerIndexes)) + for _, idx := range c.PeerIndexes { + if idx >= 0 && idx < len(c.AllPeers) { + peer := c.AllPeers[idx] + peers[peer.ID] = peer + } + } + + routerPeers := make(map[string]*nbpeer.Peer, len(c.RouterPeerIndexes)) + for _, idx := range c.RouterPeerIndexes { + if idx >= 0 && idx < len(c.AllPeers) { + peer := c.AllPeers[idx] + routerPeers[peer.ID] = peer + } + } + + groups := make(map[string]*Group, len(c.Groups)) + for id, gc := range c.Groups { + peerIDs := make([]string, 0, len(gc.PeerIndexes)) + for _, idx := range gc.PeerIndexes { + if idx >= 0 && idx < len(c.AllPeers) { + peerIDs = append(peerIDs, c.AllPeers[idx].ID) + } + } + groups[id] = &Group{ + ID: id, + Name: gc.Name, + Peers: peerIDs, + } + } + + policies := make([]*Policy, len(c.PolicyIndexes)) + for i, idx := range c.PolicyIndexes { + if idx >= 0 && idx < len(c.AllPolicies) { + policies[i] = c.AllPolicies[idx] + } + } + + var resourcePoliciesMap map[string][]*Policy + if len(c.ResourcePoliciesMap) > 0 { + resourcePoliciesMap = make(map[string][]*Policy, len(c.ResourcePoliciesMap)) + for resID, indexes := range c.ResourcePoliciesMap { + pols := make([]*Policy, 0, len(indexes)) + for _, idx := range indexes { + if idx >= 0 && idx < len(c.AllPolicies) { + pols = append(pols, c.AllPolicies[idx]) + } + } + resourcePoliciesMap[resID] = pols + } + } + + return &NetworkMapComponents{ + PeerID: c.PeerID, + Network: c.Network, + AccountSettings: c.AccountSettings, + DNSSettings: c.DNSSettings, + CustomZoneDomain: c.CustomZoneDomain, + + Peers: peers, + RouterPeers: routerPeers, + + Groups: groups, + Policies: policies, + Routes: c.Routes, + NameServerGroups: c.NameServerGroups, + AllDNSRecords: c.AllDNSRecords, + AccountZones: c.AccountZones, + + ResourcePoliciesMap: resourcePoliciesMap, + RoutersMap: c.RoutersMap, + NetworkResources: c.NetworkResources, + + GroupIDToUserIDs: c.GroupIDToUserIDs, + AllowedUserIDs: c.AllowedUserIDs, + PostureFailedPeers: c.PostureFailedPeers, + } +} diff --git a/management/server/types/networkmap_components_correctness_test.go b/management/server/types/networkmap_components_correctness_test.go new file mode 100644 index 000000000..5cd41ff10 --- /dev/null +++ b/management/server/types/networkmap_components_correctness_test.go @@ -0,0 +1,1192 @@ +package types_test + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// scalableTestAccountWithoutDefaultPolicy creates an account without the blanket "Allow All" policy. +// Use this for tests that need to verify feature-specific connectivity in isolation. +func scalableTestAccountWithoutDefaultPolicy(numPeers, numGroups int) (*types.Account, map[string]struct{}) { + return buildScalableTestAccount(numPeers, numGroups, false) +} + +// scalableTestAccount creates a realistic account with a blanket "Allow All" policy +// plus per-group policies, routes, network resources, posture checks, and DNS settings. +func scalableTestAccount(numPeers, numGroups int) (*types.Account, map[string]struct{}) { + return buildScalableTestAccount(numPeers, numGroups, true) +} + +// buildScalableTestAccount is the core builder. When withDefaultPolicy is true it adds +// a blanket group-all <-> group-all allow rule; when false the only policies are the +// per-group ones, so tests can verify feature-specific connectivity in isolation. +func buildScalableTestAccount(numPeers, numGroups int, withDefaultPolicy bool) (*types.Account, map[string]struct{}) { + peers := make(map[string]*nbpeer.Peer, numPeers) + allGroupPeers := make([]string, 0, numPeers) + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, + IP: ip, + Key: fmt.Sprintf("key-%s", peerID), + DNSLabel: fmt.Sprintf("peer%d", i), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if i == numPeers-2 { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + } + + groups := make(map[string]*types.Group, numGroups+1) + groups["group-all"] = &types.Group{ID: "group-all", Name: "All", Peers: allGroupPeers} + + peersPerGroup := numPeers / numGroups + if peersPerGroup < 1 { + peersPerGroup = 1 + } + + for g := range numGroups { + groupID := fmt.Sprintf("group-%d", g) + groupPeers := make([]string, 0, peersPerGroup) + start := g * peersPerGroup + end := start + peersPerGroup + if end > numPeers { + end = numPeers + } + for i := start; i < end; i++ { + groupPeers = append(groupPeers, fmt.Sprintf("peer-%d", i)) + } + groups[groupID] = &types.Group{ID: groupID, Name: fmt.Sprintf("Group %d", g), Peers: groupPeers} + } + + policies := make([]*types.Policy, 0, numGroups+2) + if withDefaultPolicy { + policies = append(policies, &types.Policy{ + ID: "policy-all", Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-all", Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-all"}, Destinations: []string{"group-all"}, + }}, + }) + } + + for g := range numGroups { + groupID := fmt.Sprintf("group-%d", g) + dstGroup := fmt.Sprintf("group-%d", (g+1)%numGroups) + policies = append(policies, &types.Policy{ + ID: fmt.Sprintf("policy-%d", g), Name: fmt.Sprintf("Policy %d", g), Enabled: true, + Rules: []*types.PolicyRule{{ + ID: fmt.Sprintf("rule-%d", g), Name: fmt.Sprintf("Rule %d", g), Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"8080"}, + Sources: []string{groupID}, Destinations: []string{dstGroup}, + }}, + }) + } + + if numGroups >= 2 { + policies = append(policies, &types.Policy{ + ID: "policy-drop", Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-drop", Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + } + + numRoutes := numGroups + if numRoutes > 20 { + numRoutes = 20 + } + routes := make(map[route.ID]*route.Route, numRoutes) + for r := range numRoutes { + routeID := route.ID(fmt.Sprintf("route-%d", r)) + peerIdx := (numPeers / 2) + r + if peerIdx >= numPeers { + peerIdx = numPeers - 1 + } + routePeerID := fmt.Sprintf("peer-%d", peerIdx) + groupID := fmt.Sprintf("group-%d", r%numGroups) + routes[routeID] = &route.Route{ + ID: routeID, + Network: netip.MustParsePrefix(fmt.Sprintf("10.%d.0.0/16", r)), + Peer: peers[routePeerID].Key, + PeerID: routePeerID, + Description: fmt.Sprintf("Route %d", r), + Enabled: true, + PeerGroups: []string{groupID}, + Groups: []string{"group-all"}, + AccessControlGroups: []string{groupID}, + AccountID: "test-account", + } + } + + numResources := numGroups / 2 + if numResources < 1 { + numResources = 1 + } + if numResources > 50 { + numResources = 50 + } + + networkResources := make([]*resourceTypes.NetworkResource, 0, numResources) + networksList := make([]*networkTypes.Network, 0, numResources) + networkRouters := make([]*routerTypes.NetworkRouter, 0, numResources) + + routingPeerStart := numPeers * 3 / 4 + for nr := range numResources { + netID := fmt.Sprintf("net-%d", nr) + resID := fmt.Sprintf("res-%d", nr) + routerPeerIdx := routingPeerStart + nr + if routerPeerIdx >= numPeers { + routerPeerIdx = numPeers - 1 + } + routerPeerID := fmt.Sprintf("peer-%d", routerPeerIdx) + + networksList = append(networksList, &networkTypes.Network{ID: netID, Name: fmt.Sprintf("Network %d", nr), AccountID: "test-account"}) + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: fmt.Sprintf("svc-%d.netbird.cloud", nr), + }) + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", nr), NetworkID: netID, Peer: routerPeerID, + Enabled: true, AccountID: "test-account", + }) + + policies = append(policies, &types.Policy{ + ID: fmt.Sprintf("policy-res-%d", nr), Name: fmt.Sprintf("Resource Policy %d", nr), Enabled: true, + SourcePostureChecks: []string{"posture-check-ver"}, + Rules: []*types.PolicyRule{{ + ID: fmt.Sprintf("rule-res-%d", nr), Name: fmt.Sprintf("Allow Resource %d", nr), Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{fmt.Sprintf("group-%d", nr%numGroups)}, + DestinationResource: types.Resource{ID: resID}, + }}, + }) + } + + account := &types.Account{ + Id: "test-account", + Peers: peers, + Groups: groups, + Policies: policies, + Routes: routes, + Users: map[string]*types.User{ + "user-admin": {Id: "user-admin", Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: "test-account"}, + }, + Network: &types.Network{ + Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "ns-group-main": { + ID: "ns-group-main", Name: "Main NS", Enabled: true, Groups: []string{"group-all"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: "posture-check-ver", Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: networkResources, + Networks: networksList, + NetworkRouters: networkRouters, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + validatedPeers := make(map[string]struct{}, numPeers) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if i != numPeers-1 { + validatedPeers[peerID] = struct{}{} + } + } + + return account, validatedPeers +} + +// componentsNetworkMap is a convenience wrapper for GetPeerNetworkMapFromComponents. +func componentsNetworkMap(account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap { + return account.GetPeerNetworkMapFromComponents( + context.Background(), peerID, nbdns.CustomZone{}, nil, + validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), + nil, account.GetActiveGroupUsers(), + ) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 1. PEER VISIBILITY & GROUPS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_PeerVisibility(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Equal(t, len(validatedPeers)-1-len(nm.OfflinePeers), len(nm.Peers), "peer should see all other validated non-expired peers") +} + +func TestComponents_PeerDoesNotSeeItself(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-0", p.ID, "peer should not see itself") + } +} + +func TestComponents_IntraGroupConnectivity(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "peer-0 should see peer-5 from same group") +} + +func TestComponents_CrossGroupConnectivity(t *testing.T) { + // Without default policy, only per-group policies provide connectivity + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-10"], "peer-0 should see peer-10 from cross-group policy") +} + +func TestComponents_BidirectionalPolicy(t *testing.T) { + // Without default policy so bidirectional visibility comes only from per-group policies + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5) + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + nm20 := componentsNetworkMap(account, "peer-20", validatedPeers) + require.NotNil(t, nm0) + require.NotNil(t, nm20) + + peer0SeesPeer20 := false + for _, p := range nm0.Peers { + if p.ID == "peer-20" { + peer0SeesPeer20 = true + } + } + peer20SeesPeer0 := false + for _, p := range nm20.Peers { + if p.ID == "peer-0" { + peer20SeesPeer0 = true + } + } + assert.True(t, peer0SeesPeer20, "peer-0 should see peer-20 via bidirectional policy") + assert.True(t, peer20SeesPeer0, "peer-20 should see peer-0 via bidirectional policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 2. PEER EXPIRATION & ACCOUNT SETTINGS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_ExpiredPeerInOfflineList(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + offlineIDs := make(map[string]bool, len(nm.OfflinePeers)) + for _, p := range nm.OfflinePeers { + offlineIDs[p.ID] = true + } + assert.True(t, offlineIDs["peer-98"], "expired peer should be in OfflinePeers") + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-98", p.ID, "expired peer should not be in active Peers") + } +} + +func TestComponents_ExpirationDisabledSetting(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + account.Settings.PeerLoginExpirationEnabled = false + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-98"], "with expiration disabled, peer-98 should be in active Peers") +} + +func TestComponents_LoginExpiration_PeerLevel(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.Settings.PeerLoginExpirationEnabled = true + account.Settings.PeerLoginExpiration = 1 * time.Hour + + pastLogin := time.Now().Add(-2 * time.Hour) + account.Peers["peer-5"].LastLogin = &pastLogin + account.Peers["peer-5"].LoginExpirationEnabled = true + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + offlineIDs := make(map[string]bool, len(nm.OfflinePeers)) + for _, p := range nm.OfflinePeers { + offlineIDs[p.ID] = true + } + assert.True(t, offlineIDs["peer-5"], "login-expired peer should be in OfflinePeers") + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-5", p.ID, "login-expired peer should not be in active Peers") + } +} + +func TestComponents_NetworkSerial(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + account.Network.Serial = 42 + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Equal(t, uint64(42), nm.Network.Serial, "network serial should match") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 3. NON-VALIDATED PEERS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_NonValidatedPeerExcluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in Peers") + } + for _, p := range nm.OfflinePeers { + assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in OfflinePeers") + } +} + +func TestComponents_NonValidatedTargetPeerGetsEmptyMap(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-99", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) +} + +func TestComponents_NonExistentPeerGetsEmptyMap(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-does-not-exist", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 4. POLICIES & FIREWALL RULES +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_FirewallRulesGenerated(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.FirewallRules, "should have firewall rules from policies") +} + +func TestComponents_DropPolicyGeneratesDropRules(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasDropRule := false + for _, rule := range nm.FirewallRules { + if rule.Action == string(types.PolicyTrafficActionDrop) { + hasDropRule = true + break + } + } + assert.True(t, hasDropRule, "should have at least one drop firewall rule") +} + +func TestComponents_DisabledPolicyIgnored(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + for _, p := range account.Policies { + p.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers, "disabled policies should yield no peers") + assert.Empty(t, nm.FirewallRules, "disabled policies should yield no firewall rules") +} + +func TestComponents_PortPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has8080, has5432 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "8080" { + has8080 = true + } + if rule.Port == "5432" { + has5432 = true + } + } + assert.True(t, has8080, "should have firewall rule for port 8080") + assert.True(t, has5432, "should have firewall rule for port 5432 (drop policy)") +} + +func TestComponents_PortRangePolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + account.Peers["peer-0"].Meta.WtVersion = "0.50.0" + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-port-range", Name: "Port Range", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-port-range", Name: "Port Range Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + PortRanges: []types.RulePortRange{{Start: 8000, End: 9000}}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasPortRange := false + for _, rule := range nm.FirewallRules { + if rule.PortRange.Start == 8000 && rule.PortRange.End == 9000 { + hasPortRange = true + break + } + } + assert.True(t, hasPortRange, "should have firewall rule with port range 8000-9000") +} + +func TestComponents_FirewallRuleDirection(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasIn, hasOut := false, false + for _, rule := range nm.FirewallRules { + if rule.Direction == types.FirewallRuleDirectionIN { + hasIn = true + } + if rule.Direction == types.FirewallRuleDirectionOUT { + hasOut = true + } + } + assert.True(t, hasIn, "should have inbound firewall rules") + assert.True(t, hasOut, "should have outbound firewall rules") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 5. ROUTES +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_RoutesIncluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Routes, "should have routes") +} + +func TestComponents_DisabledRouteExcluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + for _, r := range account.Routes { + r.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, r := range nm.Routes { + assert.True(t, r.Enabled, "only enabled routes should appear") + } +} + +func TestComponents_RoutesFirewallRulesForACG(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.RoutesFirewallRules, "should have route firewall rules for access-controlled routes") +} + +func TestComponents_HARouteDeduplication(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + + haNetwork := netip.MustParsePrefix("172.16.0.0/16") + account.Routes["route-ha-1"] = &route.Route{ + ID: "route-ha-1", Network: haNetwork, PeerID: "peer-10", + Peer: account.Peers["peer-10"].Key, Enabled: true, Metric: 100, + Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, AccountID: "test-account", + } + account.Routes["route-ha-2"] = &route.Route{ + ID: "route-ha-2", Network: haNetwork, PeerID: "peer-20", + Peer: account.Peers["peer-20"].Key, Enabled: true, Metric: 200, + Groups: []string{"group-all"}, PeerGroups: []string{"group-1"}, AccountID: "test-account", + } + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + haRoutes := 0 + for _, r := range nm.Routes { + if r.Network == haNetwork { + haRoutes++ + } + } + // Components deduplicates HA routes with the same HA unique ID, returning one entry per HA group + assert.Equal(t, 1, haRoutes, "HA routes with same network should be deduplicated into one entry") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 6. NETWORK RESOURCES & ROUTERS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + + var routerPeerID string + for _, nr := range account.NetworkRouters { + routerPeerID = nr.Peer + break + } + require.NotEmpty(t, routerPeerID) + + nm := componentsNetworkMap(account, routerPeerID, validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Peers, "router peer should see source peers") +} + +func TestComponents_NetworkResourceRoutes_SourcePeerSeesRouterPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + + var routerPeerID string + for _, nr := range account.NetworkRouters { + routerPeerID = nr.Peer + break + } + require.NotEmpty(t, routerPeerID) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs[routerPeerID], "source peer should see router peer for network resource") +} + +func TestComponents_DisabledNetworkResourceIgnored(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + for _, nr := range account.NetworkResources { + nr.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotNil(t, nm.Network) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 7. POSTURE CHECKS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_PostureCheckFiltering_PassingPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Routes, "passing peer should have routes including resource routes") +} + +func TestComponents_PostureCheckFiltering_FailingPeer(t *testing.T) { + // peer-0 has version 0.40.0 (passes posture check >= 0.26.0) + // peer-1 has version 0.25.0 (fails posture check >= 0.26.0) + // Resource policies require posture-check-ver, so the failing peer + // should not see the router peer for those resources. + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5) + + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + nm1 := componentsNetworkMap(account, "peer-1", validatedPeers) + require.NotNil(t, nm0) + require.NotNil(t, nm1) + + // The passing peer should have more peers visible (including resource router peers) + // than the failing peer, because the failing peer is excluded from resource policies. + assert.Greater(t, len(nm0.Peers), len(nm1.Peers), + "passing peer (0.40.0) should see more peers than failing peer (0.25.0) due to posture-gated resource policies") +} + +func TestComponents_MultiplePostureChecks(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(50, 2) + + // Keep only the posture-gated policy — remove per-group policies so connectivity is isolated + account.Policies = []*types.Policy{} + + // Set kernel version on peers so the OS posture check can evaluate + for _, p := range account.Peers { + p.Meta.KernelVersion = "5.15.0" + } + + account.PostureChecks = append(account.PostureChecks, &posture.Checks{ + ID: "posture-check-os", Name: "Check OS", + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "0.0.1"}}, + }, + }) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-posture", Name: "Multi Posture", Enabled: true, AccountID: "test-account", + SourcePostureChecks: []string{"posture-check-ver", "posture-check-os"}, + Rules: []*types.PolicyRule{{ + ID: "rule-multi-posture", Name: "Multi Check Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + + // peer-0 (0.40.0, kernel 5.15.0) passes both checks, should see group-1 peers + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + assert.NotEmpty(t, nm0.Peers, "peer passing both posture checks should see destination peers") + + // peer-1 (0.25.0, kernel 5.15.0) fails version check, should NOT see group-1 peers + nm1 := componentsNetworkMap(account, "peer-1", validatedPeers) + require.NotNil(t, nm1) + assert.Empty(t, nm1.Peers, + "peer failing posture check should see no peers when posture-gated policy is the only connectivity") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 8. DNS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_DNSConfigEnabled(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.DNSConfig.ServiceEnable, "DNS should be enabled") + assert.NotEmpty(t, nm.DNSConfig.NameServerGroups, "should have nameserver groups") +} + +func TestComponents_DNSDisabledByManagementGroup(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + account.DNSSettings.DisabledManagementGroups = []string{"group-all"} + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.False(t, nm.DNSConfig.ServiceEnable, "DNS should be disabled for peer in disabled group") +} + +func TestComponents_DNSNameServerGroupDistribution(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.NameServerGroups["ns-group-0"] = &nbdns.NameServerGroup{ + ID: "ns-group-0", Name: "Group 0 NS", Enabled: true, Groups: []string{"group-0"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}}, + } + + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + hasGroup0NS := false + for _, ns := range nm0.DNSConfig.NameServerGroups { + if ns.ID == "ns-group-0" { + hasGroup0NS = true + } + } + assert.True(t, hasGroup0NS, "peer-0 in group-0 should receive ns-group-0") + + nm10 := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm10) + hasGroup0NSForPeer10 := false + for _, ns := range nm10.DNSConfig.NameServerGroups { + if ns.ID == "ns-group-0" { + hasGroup0NSForPeer10 = true + } + } + assert.False(t, hasGroup0NSForPeer10, "peer-10 in group-1 should NOT receive ns-group-0") +} + +func TestComponents_DNSCustomZone(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + customZone := nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer0.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-0"].IP.String()}, + {Name: "peer1.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-1"].IP.String()}, + }, + } + + nm := account.GetPeerNetworkMapFromComponents( + context.Background(), "peer-0", customZone, nil, + validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), + nil, account.GetActiveGroupUsers(), + ) + require.NotNil(t, nm) + assert.True(t, nm.DNSConfig.ServiceEnable) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 9. SSH +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_SSHPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}} + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "Allow SSH", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: false, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + AuthorizedGroups: map[string][]string{"ssh-users": {"root"}}, + }}, + }) + + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be enabled for destination peer of SSH policy") +} + +func TestComponents_SSHNotEnabledWithoutPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.False(t, nm.EnableSSH, "SSH should not be enabled without SSH policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 10. CROSS-PEER CONSISTENCY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_AllPeersGetValidMaps verifies that every validated peer gets a +// non-nil map with a consistent network serial and non-empty peer list. +func TestComponents_AllPeersGetValidMaps(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + for peerID := range account.Peers { + if _, validated := validatedPeers[peerID]; !validated { + continue + } + nm := componentsNetworkMap(account, peerID, validatedPeers) + require.NotNil(t, nm, "network map should not be nil for %s", peerID) + assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID) + assert.NotEmpty(t, nm.Peers, "validated peer %s should see other peers", peerID) + } +} + +// TestComponents_LargeScaleMapGeneration verifies that components can generate maps +// at larger scales without errors and with consistent output. +func TestComponents_LargeScaleMapGeneration(t *testing.T) { + scales := []struct{ peers, groups int }{ + {500, 20}, + {1000, 50}, + } + for _, s := range scales { + t.Run(fmt.Sprintf("%dpeers_%dgroups", s.peers, s.groups), func(t *testing.T) { + account, validatedPeers := scalableTestAccount(s.peers, s.groups) + testPeers := []string{"peer-0", fmt.Sprintf("peer-%d", s.peers/4), fmt.Sprintf("peer-%d", s.peers/2)} + for _, peerID := range testPeers { + nm := componentsNetworkMap(account, peerID, validatedPeers) + require.NotNil(t, nm, "network map should not be nil for %s", peerID) + assert.NotEmpty(t, nm.Peers, "peer %s should see other peers at scale", peerID) + assert.NotEmpty(t, nm.Routes, "peer %s should have routes at scale", peerID) + assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID) + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// 11. PEER-AS-RESOURCE POLICIES +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_PeerAsSourceResource verifies that a policy with SourceResource.Type=Peer +// targets only that specific peer as the source. +func TestComponents_PeerAsSourceResource(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-peer-src", Name: "Peer Source Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-peer-src", Name: "Peer Source Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"443"}, + SourceResource: types.Resource{ID: "peer-0", Type: types.ResourceTypePeer}, + Destinations: []string{"group-1"}, + }}, + }) + + // peer-0 is the source resource, should see group-1 peers + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + + has443 := false + for _, rule := range nm0.FirewallRules { + if rule.Port == "443" { + has443 = true + break + } + } + assert.True(t, has443, "peer-0 as source resource should have port 443 rule") +} + +// TestComponents_PeerAsDestinationResource verifies that a policy with DestinationResource.Type=Peer +// targets only that specific peer as the destination. +func TestComponents_PeerAsDestinationResource(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-peer-dst", Name: "Peer Dest Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-peer-dst", Name: "Peer Dest Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"443"}, + Sources: []string{"group-0"}, + DestinationResource: types.Resource{ID: "peer-15", Type: types.ResourceTypePeer}, + }}, + }) + + // peer-0 is in group-0 (source), should see peer-15 as destination + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + + peerIDs := make(map[string]bool, len(nm0.Peers)) + for _, p := range nm0.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-15"], "peer-0 should see peer-15 via peer-as-destination-resource policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 12. MULTIPLE RULES PER POLICY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_MultipleRulesPerPolicy verifies a policy with multiple rules generates +// firewall rules for each. +func TestComponents_MultipleRulesPerPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-rule", Name: "Multi Rule Policy", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{ + { + ID: "rule-http", Name: "Allow HTTP", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"80"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + { + ID: "rule-https", Name: "Allow HTTPS", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"443"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + }, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has80, has443 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "80" { + has80 = true + } + if rule.Port == "443" { + has443 = true + } + } + assert.True(t, has80, "should have firewall rule for port 80 from first rule") + assert.True(t, has443, "should have firewall rule for port 443 from second rule") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 13. SSH AUTHORIZED USERS CONTENT +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_SSHAuthorizedUsersContent verifies that SSH policies populate +// the AuthorizedUsers map with the correct users and machine mappings. +func TestComponents_SSHAuthorizedUsersContent(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Users["user-dev"] = &types.User{Id: "user-dev", Role: types.UserRoleUser, AccountID: "test-account", AutoGroups: []string{"ssh-users"}} + account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}} + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "Allow SSH", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: false, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + AuthorizedGroups: map[string][]string{"ssh-users": {"root", "admin"}}, + }}, + }) + + // peer-10 is in group-1 (destination) + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be enabled") + assert.NotNil(t, nm.AuthorizedUsers, "AuthorizedUsers should not be nil") + assert.NotEmpty(t, nm.AuthorizedUsers, "AuthorizedUsers should have entries") + + // Check that "root" machine user mapping exists + _, hasRoot := nm.AuthorizedUsers["root"] + _, hasAdmin := nm.AuthorizedUsers["admin"] + assert.True(t, hasRoot || hasAdmin, "AuthorizedUsers should contain 'root' or 'admin' machine user mapping") +} + +// TestComponents_SSHLegacyImpliedSSH verifies that a non-SSH ALL protocol policy with +// SSHEnabled peer implies legacy SSH access. +func TestComponents_SSHLegacyImpliedSSH(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // Enable SSH on the destination peer + account.Peers["peer-10"].SSHEnabled = true + + // The default "Allow All" policy with Protocol=ALL + SSHEnabled peer should imply SSH + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be implied by ALL protocol policy with SSHEnabled peer") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 14. ROUTE DEFAULT PERMIT (no AccessControlGroups) +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_RouteDefaultPermit verifies that a route without AccessControlGroups +// generates default permit firewall rules (0.0.0.0/0 source). +func TestComponents_RouteDefaultPermit(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // Add a route without ACGs — this peer is the routing peer + routingPeerID := "peer-5" + account.Routes["route-no-acg"] = &route.Route{ + ID: "route-no-acg", Network: netip.MustParsePrefix("192.168.99.0/24"), + PeerID: routingPeerID, Peer: account.Peers[routingPeerID].Key, + Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, + AccessControlGroups: []string{}, + AccountID: "test-account", + } + + // The routing peer should get default permit route firewall rules + nm := componentsNetworkMap(account, routingPeerID, validatedPeers) + require.NotNil(t, nm) + + hasDefaultPermit := false + for _, rfr := range nm.RoutesFirewallRules { + for _, src := range rfr.SourceRanges { + if src == "0.0.0.0/0" || src == "::/0" { + hasDefaultPermit = true + break + } + } + } + assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 15. MULTIPLE ROUTERS PER NETWORK +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_MultipleRoutersPerNetwork verifies that a network resource +// with multiple routers provides routes through all available routers. +func TestComponents_MultipleRoutersPerNetwork(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + netID := "net-multi-router" + resID := "res-multi-router" + account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Multi Router Network", AccountID: "test-account"}) + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: "multi-svc.netbird.cloud", + }) + account.NetworkRouters = append(account.NetworkRouters, + &routerTypes.NetworkRouter{ID: "router-a", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", Metric: 100}, + &routerTypes.NetworkRouter{ID: "router-b", NetworkID: netID, Peer: "peer-15", Enabled: true, AccountID: "test-account", Metric: 200}, + ) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-router-res", Name: "Multi Router Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-multi-router-res", Name: "Allow Multi Router", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID}, + }}, + }) + + // peer-0 is in group-0 (source), should see both router peers + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "source peer should see router-a (peer-5)") + assert.True(t, peerIDs["peer-15"], "source peer should see router-b (peer-15)") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 16. PEER-AS-NAMESERVER EXCLUSION +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_PeerIsNameserverExcludedFromNSGroup verifies that a peer serving +// as a nameserver does not receive its own NS group in DNS config. +func TestComponents_PeerIsNameserverExcludedFromNSGroup(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // peer-0 has IP 100.64.0.0 — make it a nameserver + nsIP := account.Peers["peer-0"].IP + account.NameServerGroups["ns-self"] = &nbdns.NameServerGroup{ + ID: "ns-self", Name: "Self NS", Enabled: true, Groups: []string{"group-all"}, + NameServers: []nbdns.NameServer{{IP: netip.AddrFrom4([4]byte{nsIP[0], nsIP[1], nsIP[2], nsIP[3]}), NSType: nbdns.UDPNameServerType, Port: 53}}, + } + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasSelfNS := false + for _, ns := range nm.DNSConfig.NameServerGroups { + if ns.ID == "ns-self" { + hasSelfNS = true + } + } + assert.False(t, hasSelfNS, "peer serving as nameserver should NOT receive its own NS group") + + // peer-10 is NOT the nameserver, should receive the NS group + nm10 := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm10) + hasNSForPeer10 := false + for _, ns := range nm10.DNSConfig.NameServerGroups { + if ns.ID == "ns-self" { + hasNSForPeer10 = true + } + } + assert.True(t, hasNSForPeer10, "non-nameserver peer should receive the NS group") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 17. DOMAIN NETWORK RESOURCES +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_DomainNetworkResource verifies that domain-based network resources +// produce routes with the correct domain configuration. +func TestComponents_DomainNetworkResource(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + netID := "net-domain" + resID := "res-domain" + account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Domain Network", AccountID: "test-account"}) + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: "api.example.com", Type: "domain", + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-domain", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", + }) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-domain-res", Name: "Domain Resource Policy", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-domain-res", Name: "Allow Domain", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID}, + }}, + }) + + // peer-0 is source, should get route to the domain resource via peer-5 + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "source peer should see domain resource router peer") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 18. DISABLED RULE WITHIN ENABLED POLICY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_DisabledRuleInEnabledPolicy verifies that a disabled rule within +// an enabled policy does not generate firewall rules. +func TestComponents_DisabledRuleInEnabledPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-mixed-rules", Name: "Mixed Rules", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{ + { + ID: "rule-enabled", Name: "Enabled Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"3000"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + { + ID: "rule-disabled", Name: "Disabled Rule", Enabled: false, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"3001"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + }, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has3000, has3001 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "3000" { + has3000 = true + } + if rule.Port == "3001" { + has3001 = true + } + } + assert.True(t, has3000, "enabled rule should generate firewall rule for port 3000") + assert.False(t, has3001, "disabled rule should NOT generate firewall rule for port 3001") +} diff --git a/management/server/types/settings.go b/management/server/types/settings.go index a94e01b78..4ea79ec72 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -47,6 +47,11 @@ type Settings struct { // NetworkRange is the custom network range for that account NetworkRange netip.Prefix `gorm:"serializer:json"` + // PeerExposeEnabled enables or disables peer-initiated service expose + PeerExposeEnabled bool + // PeerExposeGroups list of peer group IDs allowed to expose services + PeerExposeGroups []string `gorm:"serializer:json"` + // Extra is a dictionary of Account settings Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` @@ -56,6 +61,10 @@ type Settings struct { // AutoUpdateVersion client auto-update version AutoUpdateVersion string `gorm:"default:'disabled'"` + // AutoUpdateAlways when true, updates are installed automatically in the background; + // when false, updates require user interaction from the UI + AutoUpdateAlways bool `gorm:"default:false"` + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. // This is a runtime-only field, not stored in the database. EmbeddedIdpEnabled bool `gorm:"-"` @@ -80,10 +89,13 @@ func (s *Settings) Copy() *Settings { PeerInactivityExpiration: s.PeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + PeerExposeEnabled: s.PeerExposeEnabled, + PeerExposeGroups: slices.Clone(s.PeerExposeGroups), LazyConnectionEnabled: s.LazyConnectionEnabled, DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, AutoUpdateVersion: s.AutoUpdateVersion, + AutoUpdateAlways: s.AutoUpdateAlways, EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, } diff --git a/management/server/user.go b/management/server/user.go index 924efc1e4..c1f984f2f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -417,6 +417,10 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, err } + if targetUser.AccountID != accountID { + return nil, status.NewPermissionDeniedError() + } + // @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() @@ -457,6 +461,10 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return err } + if targetUser.AccountID != accountID { + return status.NewPermissionDeniedError() + } + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return status.NewAdminPermissionError() } @@ -496,6 +504,10 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } + if targetUser.AccountID != accountID { + return nil, status.NewPermissionDeniedError() + } + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() } @@ -523,6 +535,10 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, err } + if targetUser.AccountID != accountID { + return nil, status.NewPermissionDeniedError() + } + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() } @@ -742,6 +758,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact if err != nil { return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err) } + + // Ensure the initiator still has admin privileges + if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() { + return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing") + } initiatorUser = freshInitiator } @@ -759,9 +780,15 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact updatedUser.Role = update.Role updatedUser.Blocked = update.Blocked updatedUser.AutoGroups = update.AutoGroups - // these two fields can't be set via API, only via direct call to the method + // these fields can't be set via API, only via direct call to the method updatedUser.Issued = update.Issued updatedUser.IntegrationReference = update.IntegrationReference + if update.Name != "" { + updatedUser.Name = update.Name + } + if update.Email != "" { + updatedUser.Email = update.Email + } var transferredOwnerRole bool result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update) @@ -872,10 +899,6 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse return nil } - if !initiatorUser.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only admins and owners can update users") - } - if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } diff --git a/management/server/user_test.go b/management/server/user_test.go index 72a19a9a5..8fdfbd633 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -336,6 +336,104 @@ func TestUser_GetAllPATs(t *testing.T) { assert.Equal(t, 2, len(pats)) } +func TestUser_PAT_CrossAccountProtection(t *testing.T) { + const ( + accountAID = "accountA" + accountBID = "accountB" + userAID = "userA" + adminBID = "adminB" + serviceUserBID = "serviceUserB" + regularUserBID = "regularUserB" + tokenBID = "tokenB1" + hashedTokenB = "SoMeHaShEdToKeNB" + ) + + setupStore := func(t *testing.T) (*DefaultAccountManager, func()) { + t.Helper() + + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err, "creating store") + + accountA := newAccountWithId(context.Background(), accountAID, userAID, "", "", "", false) + require.NoError(t, s.SaveAccount(context.Background(), accountA)) + + accountB := newAccountWithId(context.Background(), accountBID, adminBID, "", "", "", false) + accountB.Users[serviceUserBID] = &types.User{ + Id: serviceUserBID, + AccountID: accountBID, + IsServiceUser: true, + ServiceUserName: "svcB", + Role: types.UserRoleAdmin, + PATs: map[string]*types.PersonalAccessToken{ + tokenBID: { + ID: tokenBID, + HashedToken: hashedTokenB, + }, + }, + } + accountB.Users[regularUserBID] = &types.User{ + Id: regularUserBID, + AccountID: accountBID, + Role: types.UserRoleUser, + } + require.NoError(t, s.SaveAccount(context.Background(), accountB)) + + pm := permissions.NewManager(s) + am := &DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: pm, + } + return am, cleanup + } + + t.Run("CreatePAT for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.CreatePAT(context.Background(), accountAID, userAID, serviceUserBID, "xss-token", 7) + require.Error(t, err, "cross-account CreatePAT must fail") + + _, err = am.CreatePAT(context.Background(), accountAID, userAID, regularUserBID, "xss-token", 7) + require.Error(t, err, "cross-account CreatePAT for regular user must fail") + + _, err = am.CreatePAT(context.Background(), accountBID, adminBID, serviceUserBID, "legit-token", 7) + require.NoError(t, err, "same-account CreatePAT should succeed") + }) + + t.Run("DeletePAT for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + err := am.DeletePAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID) + require.Error(t, err, "cross-account DeletePAT must fail") + }) + + t.Run("GetPAT for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.GetPAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID) + require.Error(t, err, "cross-account GetPAT must fail") + }) + + t.Run("GetAllPATs for user in different account is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.GetAllPATs(context.Background(), accountAID, userAID, serviceUserBID) + require.Error(t, err, "cross-account GetAllPATs must fail") + }) + + t.Run("CreatePAT with forged accountID targeting foreign user is denied", func(t *testing.T) { + am, cleanup := setupStore(t) + t.Cleanup(cleanup) + + _, err := am.CreatePAT(context.Background(), accountAID, userAID, adminBID, "forged", 7) + require.Error(t, err, "forged accountID CreatePAT must fail") + }) +} + func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way user := types.User{ @@ -2032,27 +2130,6 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) { }) } -func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) { - groupsMap := map[string]*types.Group{} - - initiator := &types.User{ - Id: "initiator", - Role: types.UserRoleUser, - } - oldUser := &types.User{ - Id: "target", - Role: types.UserRoleUser, - } - update := &types.User{ - Id: "target", - Role: types.UserRoleOwner, - } - - err := validateUserUpdate(groupsMap, initiator, oldUser, update) - require.Error(t, err, "regular user should not be able to promote to owner") - assert.Contains(t, err.Error(), "only admins and owners can update users") -} - func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) { s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) require.NoError(t, err) @@ -2109,7 +2186,7 @@ func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) { }) require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted") - assert.Contains(t, err.Error(), "only admins and owners can update users") + assert.Contains(t, err.Error(), "initiator role was changed during request processing") targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID) require.NoError(t, err) diff --git a/proxy/Dockerfile b/proxy/Dockerfile index 096c71f21..e64680fd6 100644 --- a/proxy/Dockerfile +++ b/proxy/Dockerfile @@ -10,7 +10,7 @@ FROM gcr.io/distroless/base:debug COPY netbird-proxy /go/bin/netbird-proxy COPY --from=builder /tmp/passwd /etc/passwd COPY --from=builder /tmp/group /etc/group -COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird +COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs USER netbird:netbird ENV HOME=/var/lib/netbird diff --git a/proxy/Dockerfile.multistage b/proxy/Dockerfile.multistage index 2e3ac3561..01e342c0e 100644 --- a/proxy/Dockerfile.multistage +++ b/proxy/Dockerfile.multistage @@ -28,7 +28,7 @@ FROM gcr.io/distroless/base:debug COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy COPY --from=builder /tmp/passwd /etc/passwd COPY --from=builder /tmp/group /etc/group -COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird +COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs USER netbird:netbird ENV HOME=/var/lib/netbird diff --git a/proxy/auth/auth.go b/proxy/auth/auth.go index 14caa03b3..ca9c260b7 100644 --- a/proxy/auth/auth.go +++ b/proxy/auth/auth.go @@ -13,10 +13,11 @@ import ( type Method string -var ( +const ( MethodPassword Method = "password" MethodPIN Method = "pin" MethodOIDC Method = "oidc" + MethodHeader Method = "header" ) func (m Method) String() string { diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 121621109..ec8980ad9 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -7,6 +7,7 @@ import ( "os/signal" "strconv" "syscall" + "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -34,25 +35,37 @@ var ( ) var ( - debugLogs bool - mgmtAddr string - addr string - proxyDomain string - certDir string - acmeCerts bool - acmeAddr string - acmeDir string - acmeChallengeType string - debugEndpoint bool - debugEndpointAddr string - healthAddr string - forwardedProto string - trustedProxies string - certFile string - certKeyFile string - certLockMethod string - wgPort int - proxyProtocol bool + logLevel string + debugLogs bool + mgmtAddr string + addr string + proxyDomain string + maxDialTimeout time.Duration + maxSessionIdleTimeout time.Duration + certDir string + acmeCerts bool + acmeAddr string + acmeDir string + acmeEABKID string + acmeEABHMACKey string + acmeChallengeType string + debugEndpoint bool + debugEndpointAddr string + healthAddr string + forwardedProto string + trustedProxies string + certFile string + certKeyFile string + certLockMethod string + wildcardCertDir string + wgPort uint16 + proxyProtocol bool + preSharedKey string + supportsCustomPorts bool + requireSubdomain bool + geoDataDir string + crowdsecAPIURL string + crowdsecAPIKey string ) var rootCmd = &cobra.Command{ @@ -65,7 +78,9 @@ var rootCmd = &cobra.Command{ } func init() { + rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", envStringOrDefault("NB_PROXY_LOG_LEVEL", "info"), "Log level: panic, fatal, error, warn, info, debug, trace") rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs") + _ = rootCmd.PersistentFlags().MarkDeprecated("debug", "use --log-level instead") rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to") rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on") rootCmd.Flags().StringVar(&proxyDomain, "domain", envStringOrDefault("NB_PROXY_DOMAIN", ""), "The Domain at which this proxy will be reached. e.g., netbird.example.com") @@ -73,6 +88,8 @@ func init() { rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates automatically") rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges (only used when acme-challenge-type is http-01)") rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory") + rootCmd.Flags().StringVar(&acmeEABKID, "acme-eab-kid", envStringOrDefault("NB_PROXY_ACME_EAB_KID", ""), "ACME EAB KID for account registration") + rootCmd.Flags().StringVar(&acmeEABHMACKey, "acme-eab-hmac-key", envStringOrDefault("NB_PROXY_ACME_EAB_HMAC_KEY", ""), "ACME EAB HMAC key for account registration") rootCmd.Flags().StringVar(&acmeChallengeType, "acme-challenge-type", envStringOrDefault("NB_PROXY_ACME_CHALLENGE_TYPE", "tls-alpn-01"), "ACME challenge type: tls-alpn-01 (default, port 443 only) or http-01 (requires port 80)") rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint") rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint") @@ -82,8 +99,17 @@ func init() { rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory") rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") - rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") + rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (.crt/.key). Wildcard patterns are extracted from SANs automatically") + rootCmd.Flags().Uint16Var(&wgPort, "wg-port", envUint16OrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") + rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") + rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") + rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain") + rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)") + rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)") + rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)") + rootCmd.Flags().StringVar(&crowdsecAPIURL, "crowdsec-api-url", envStringOrDefault("NB_PROXY_CROWDSEC_API_URL", ""), "CrowdSec LAPI URL for IP reputation checks") + rootCmd.Flags().StringVar(&crowdsecAPIKey, "crowdsec-api-key", envStringOrDefault("NB_PROXY_CROWDSEC_API_KEY", ""), "CrowdSec bouncer API key") } // Execute runs the root command. @@ -109,7 +135,7 @@ func runServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken) } - level := "error" + level := logLevel if debugLogs { level = "debug" } @@ -147,6 +173,8 @@ func runServer(cmd *cobra.Command, args []string) error { GenerateACMECertificates: acmeCerts, ACMEChallengeAddress: acmeAddr, ACMEDirectory: acmeDir, + ACMEEABKID: acmeEABKID, + ACMEEABHMACKey: acmeEABHMACKey, ACMEChallengeType: acmeChallengeType, DebugEndpointEnabled: debugEndpoint, DebugEndpointAddress: debugEndpointAddr, @@ -154,18 +182,23 @@ func runServer(cmd *cobra.Command, args []string) error { ForwardedProto: forwardedProto, TrustedProxies: parsedTrustedProxies, CertLockMethod: nbacme.CertLockMethod(certLockMethod), + WildcardCertDir: wildcardCertDir, WireguardPort: wgPort, ProxyProtocol: proxyProtocol, + PreSharedKey: preSharedKey, + SupportsCustomPorts: supportsCustomPorts, + RequireSubdomain: requireSubdomain, + MaxDialTimeout: maxDialTimeout, + MaxSessionIdleTimeout: maxSessionIdleTimeout, + GeoDataDir: geoDataDir, + CrowdSecAPIURL: crowdsecAPIURL, + CrowdSecAPIKey: crowdsecAPIKey, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) defer stop() - if err := srv.ListenAndServe(ctx, addr); err != nil { - logger.Error(err) - return err - } - return nil + return srv.ListenAndServe(ctx, addr) } func envBoolOrDefault(key string, def bool) bool { @@ -175,6 +208,7 @@ func envBoolOrDefault(key string, def bool) bool { } parsed, err := strconv.ParseBool(v) if err != nil { + log.Warnf("parse %s=%q: %v, using default %v", key, v, err, def) return def } return parsed @@ -188,13 +222,27 @@ func envStringOrDefault(key string, def string) string { return v } -func envIntOrDefault(key string, def int) int { +func envUint16OrDefault(key string, def uint16) uint16 { v, exists := os.LookupEnv(key) if !exists { return def } - parsed, err := strconv.Atoi(v) + parsed, err := strconv.ParseUint(v, 10, 16) if err != nil { + log.Warnf("parse %s=%q: %v, using default %d", key, v, err, def) + return def + } + return uint16(parsed) +} + +func envDurationOrDefault(key string, def time.Duration) time.Duration { + v, exists := os.LookupEnv(key) + if !exists { + return def + } + parsed, err := time.ParseDuration(v) + if err != nil { + log.Warnf("parse %s=%q: %v, using default %s", key, v, err, def) return def } return parsed diff --git a/proxy/cmd/proxy/main.go b/proxy/cmd/proxy/main.go index 14e540a2e..16e7e8ac2 100644 --- a/proxy/cmd/proxy/main.go +++ b/proxy/cmd/proxy/main.go @@ -1,8 +1,13 @@ package main import ( + "net/http" + // nolint:gosec + _ "net/http/pprof" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/proxy/cmd/proxy/cmd" ) @@ -21,6 +26,9 @@ var ( ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() cmd.SetVersionInfo(Version, Commit, BuildDate, GoVersion) cmd.Execute() } diff --git a/proxy/handle_mapping_stream_test.go b/proxy/handle_mapping_stream_test.go index d2ad3f67e..cb16c0814 100644 --- a/proxy/handle_mapping_stream_test.go +++ b/proxy/handle_mapping_stream_test.go @@ -38,11 +38,18 @@ func (m *mockMappingStream) Context() context.Context { return context.Backgroun func (m *mockMappingStream) SendMsg(any) error { return nil } func (m *mockMappingStream) RecvMsg(any) error { return nil } +func closedChan() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) { checker := health.NewChecker(nil, nil) s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -62,6 +69,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -78,7 +86,8 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { func TestHandleMappingStream_NilHealthChecker(t *testing.T) { s := &Server{ - Logger: log.StandardLogger(), + Logger: log.StandardLogger(), + routerReady: closedChan(), } stream := &mockMappingStream{ diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 9e204be65..3283f61db 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -2,26 +2,81 @@ package accesslog import ( "context" + "maps" "net/netip" + "sync" + "sync/atomic" "time" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) +const ( + requestThreshold = 10000 // Log every 10k requests + bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB + usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour + usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours + logSendTimeout = 10 * time.Second + + // denyCooldown is the min interval between deny log entries per service+reason + // to prevent flooding from denied connections (e.g. UDP packets from blocked IPs). + denyCooldown = 10 * time.Second + + // maxDenyBuckets caps tracked deny rate-limit entries to bound memory under DDoS. + maxDenyBuckets = 10000 + + // maxLogWorkers caps concurrent gRPC send goroutines. + maxLogWorkers = 4096 +) + +type domainUsage struct { + requestCount int64 + requestStartTime time.Time + + bytesTransferred int64 + bytesStartTime time.Time + + lastActivity time.Time // Track last activity for cleanup +} + type gRPCClient interface { SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) } +// denyBucketKey identifies a rate-limited deny log stream. +type denyBucketKey struct { + ServiceID types.ServiceID + Reason string +} + +// denyBucket tracks rate-limited deny log entries. +type denyBucket struct { + lastLogged time.Time + suppressed int64 +} + // Logger sends access log entries to the management server via gRPC. type Logger struct { client gRPCClient logger *log.Logger trustedProxies []netip.Prefix + + usageMux sync.Mutex + domainUsage map[string]*domainUsage + + denyMu sync.Mutex + denyBuckets map[denyBucketKey]*denyBucket + + logSem chan struct{} + cleanupCancel context.CancelFunc + dropped atomic.Int64 } // NewLogger creates a new access log Logger. The trustedProxies parameter @@ -31,29 +86,135 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre if logger == nil { logger = log.StandardLogger() } - return &Logger{ + + ctx, cancel := context.WithCancel(context.Background()) + l := &Logger{ client: client, logger: logger, trustedProxies: trustedProxies, + domainUsage: make(map[string]*domainUsage), + denyBuckets: make(map[denyBucketKey]*denyBucket), + logSem: make(chan struct{}, maxLogWorkers), + cleanupCancel: cancel, + } + + // Start background cleanup routine + go l.cleanupStaleUsage(ctx) + + return l +} + +// Close stops the cleanup routine. Should be called during graceful shutdown. +func (l *Logger) Close() { + if l.cleanupCancel != nil { + l.cleanupCancel() } } type logEntry struct { ID string - AccountID string - ServiceId string + AccountID types.AccountID + ServiceID types.ServiceID Host string Path string DurationMs int64 Method string ResponseCode int32 - SourceIp string + SourceIP netip.Addr AuthMechanism string - UserId string + UserID string AuthSuccess bool + BytesUpload int64 + BytesDownload int64 + Protocol Protocol + Metadata map[string]string } -func (l *Logger) log(ctx context.Context, entry logEntry) { +// Protocol identifies the transport protocol of an access log entry. +type Protocol string + +const ( + ProtocolHTTP Protocol = "http" + ProtocolTCP Protocol = "tcp" + ProtocolUDP Protocol = "udp" + ProtocolTLS Protocol = "tls" +) + +// L4Entry holds the data for a layer-4 (TCP/UDP) access log entry. +type L4Entry struct { + AccountID types.AccountID + ServiceID types.ServiceID + Protocol Protocol + Host string // SNI hostname or listen address + SourceIP netip.Addr + DurationMs int64 + BytesUpload int64 + BytesDownload int64 + // DenyReason, when non-empty, indicates the connection was denied. + // Values match the HTTP auth mechanism strings: "ip_restricted", + // "country_restricted", "geo_unavailable", "crowdsec_ban", etc. + DenyReason string + // Metadata carries extra context about the connection (e.g. CrowdSec verdict). + Metadata map[string]string +} + +// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). +// The call is non-blocking: the gRPC send happens in a background goroutine. +func (l *Logger) LogL4(entry L4Entry) { + le := logEntry{ + ID: xid.New().String(), + AccountID: entry.AccountID, + ServiceID: entry.ServiceID, + Protocol: entry.Protocol, + Host: entry.Host, + SourceIP: entry.SourceIP, + DurationMs: entry.DurationMs, + BytesUpload: entry.BytesUpload, + BytesDownload: entry.BytesDownload, + Metadata: maps.Clone(entry.Metadata), + } + if entry.DenyReason != "" { + if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) { + return + } + le.AuthMechanism = entry.DenyReason + le.AuthSuccess = false + } + l.log(le) + l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload) +} + +// allowDenyLog rate-limits deny log entries per service+reason combination. +func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool { + key := denyBucketKey{ServiceID: serviceID, Reason: reason} + now := time.Now() + + l.denyMu.Lock() + defer l.denyMu.Unlock() + + b, ok := l.denyBuckets[key] + if !ok { + if len(l.denyBuckets) >= maxDenyBuckets { + return false + } + l.denyBuckets[key] = &denyBucket{lastLogged: now} + return true + } + + if now.Sub(b.lastLogged) >= denyCooldown { + if b.suppressed > 0 { + l.logger.Debugf("access restriction: suppressed %d deny log entries for %s (%s)", b.suppressed, serviceID, reason) + } + b.lastLogged = now + b.suppressed = 0 + return true + } + + b.suppressed++ + return false +} + +func (l *Logger) log(entry logEntry) { // Fire off the log request in a separate routine. // This increases the possibility of losing a log message // (although it should still get logged in the event of an error), @@ -62,44 +223,162 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { // There is also a chance that log messages will arrive at // the server out of order; however, the timestamp should // allow for resolving that on the server. - now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary. + now := timestamppb.Now() + select { + case l.logSem <- struct{}{}: + default: + total := l.dropped.Add(1) + l.logger.Debugf("access log send dropped: worker limit reached (total dropped: %d)", total) + return + } go func() { - logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer func() { <-l.logSem }() + logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout) defer cancel() + // Only OIDC sessions have a meaningful user identity. if entry.AuthMechanism != auth.MethodOIDC.String() { - entry.UserId = "" + entry.UserID = "" } + + var sourceIP string + if entry.SourceIP.IsValid() { + sourceIP = entry.SourceIP.String() + } + if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{ Log: &proto.AccessLog{ LogId: entry.ID, - AccountId: entry.AccountID, + AccountId: string(entry.AccountID), Timestamp: now, - ServiceId: entry.ServiceId, + ServiceId: string(entry.ServiceID), Host: entry.Host, Path: entry.Path, DurationMs: entry.DurationMs, Method: entry.Method, ResponseCode: entry.ResponseCode, - SourceIp: entry.SourceIp, + SourceIp: sourceIP, AuthMechanism: entry.AuthMechanism, - UserId: entry.UserId, + UserId: entry.UserID, AuthSuccess: entry.AuthSuccess, + BytesUpload: entry.BytesUpload, + BytesDownload: entry.BytesDownload, + Protocol: string(entry.Protocol), + Metadata: entry.Metadata, }, }); err != nil { - // If it fails to send on the gRPC connection, then at least log it to the error log. l.logger.WithFields(log.Fields{ - "service_id": entry.ServiceId, + "service_id": entry.ServiceID, "host": entry.Host, "path": entry.Path, "duration": entry.DurationMs, "method": entry.Method, "response_code": entry.ResponseCode, - "source_ip": entry.SourceIp, + "source_ip": sourceIP, "auth_mechanism": entry.AuthMechanism, - "user_id": entry.UserId, + "user_id": entry.UserID, "auth_success": entry.AuthSuccess, "error": err, }).Error("Error sending access log on gRPC connection") } }() } + +// trackUsage records request and byte counts per domain, logging when thresholds are hit. +func (l *Logger) trackUsage(domain string, bytesTransferred int64) { + if domain == "" { + return + } + + l.usageMux.Lock() + defer l.usageMux.Unlock() + + now := time.Now() + usage, exists := l.domainUsage[domain] + if !exists { + usage = &domainUsage{ + requestStartTime: now, + bytesStartTime: now, + lastActivity: now, + } + l.domainUsage[domain] = usage + } + + usage.lastActivity = now + + usage.requestCount++ + if usage.requestCount >= requestThreshold { + elapsed := time.Since(usage.requestStartTime) + l.logger.WithFields(log.Fields{ + "domain": domain, + "requests": usage.requestCount, + "duration": elapsed.String(), + }).Infof("domain %s had %d requests over %s", domain, usage.requestCount, elapsed) + + usage.requestCount = 0 + usage.requestStartTime = now + } + + usage.bytesTransferred += bytesTransferred + if usage.bytesTransferred >= bytesThreshold { + elapsed := time.Since(usage.bytesStartTime) + bytesInGB := float64(usage.bytesTransferred) / (1024 * 1024 * 1024) + l.logger.WithFields(log.Fields{ + "domain": domain, + "bytes": usage.bytesTransferred, + "bytes_gb": bytesInGB, + "duration": elapsed.String(), + }).Infof("domain %s transferred %.2f GB over %s", domain, bytesInGB, elapsed) + + usage.bytesTransferred = 0 + usage.bytesStartTime = now + } +} + +// cleanupStaleUsage removes usage and deny-rate-limit entries that have been inactive. +func (l *Logger) cleanupStaleUsage(ctx context.Context) { + ticker := time.NewTicker(usageCleanupPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + l.cleanupDomainUsage(now) + l.cleanupDenyBuckets(now) + } + } +} + +func (l *Logger) cleanupDomainUsage(now time.Time) { + l.usageMux.Lock() + defer l.usageMux.Unlock() + + removed := 0 + for domain, usage := range l.domainUsage { + if now.Sub(usage.lastActivity) > usageInactiveWindow { + delete(l.domainUsage, domain) + removed++ + } + } + if removed > 0 { + l.logger.Debugf("cleaned up %d stale domain usage entries", removed) + } +} + +func (l *Logger) cleanupDenyBuckets(now time.Time) { + l.denyMu.Lock() + defer l.denyMu.Unlock() + + removed := 0 + for key, bucket := range l.denyBuckets { + if now.Sub(bucket.lastLogged) > usageInactiveWindow { + delete(l.denyBuckets, key) + removed++ + } + } + if removed > 0 { + l.logger.Debugf("cleaned up %d stale deny rate-limit entries", removed) + } +} diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index dd4798975..5a0684c19 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/proxy/web" ) +// Middleware wraps an HTTP handler to log access entries and resolve client IPs. func (l *Logger) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip logging for internal proxy assets (CSS, JS, etc.) @@ -32,6 +33,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { status: http.StatusOK, } + var bytesRead int64 + if r.Body != nil { + r.Body = &bodyCounter{ + ReadCloser: r.Body, + bytesRead: &bytesRead, + } + } + // Resolve the source IP using trusted proxy configuration before passing // the request on, as the proxy will modify forwarding headers. sourceIp := extractSourceIP(r, l.trustedProxies) @@ -39,8 +48,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { // Create a mutable struct to capture data from downstream handlers. // We pass a pointer in the context - the pointer itself flows down immutably, // but the struct it points to can be mutated by inner handlers. - capturedData := &proxy.CapturedData{RequestID: requestID} + capturedData := proxy.NewCapturedData(requestID) capturedData.SetClientIP(sourceIp) + ctx := proxy.WithCapturedData(r.Context(), capturedData) start := time.Now() @@ -53,23 +63,33 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { host = r.Host } + bytesUpload := bytesRead + bytesDownload := sw.bytesWritten + entry := logEntry{ ID: requestID, - ServiceId: capturedData.GetServiceId(), - AccountID: string(capturedData.GetAccountId()), + ServiceID: capturedData.GetServiceID(), + AccountID: capturedData.GetAccountID(), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), Method: r.Method, ResponseCode: int32(sw.status), - SourceIp: sourceIp, + SourceIP: sourceIp, AuthMechanism: capturedData.GetAuthMethod(), - UserId: capturedData.GetUserID(), + UserID: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, + BytesUpload: bytesUpload, + BytesDownload: bytesDownload, + Protocol: ProtocolHTTP, + Metadata: capturedData.GetMetadata(), } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", - requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) + requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID()) - l.log(r.Context(), entry) + l.log(entry) + + // Track usage for cost monitoring (upload + download) by domain + l.trackUsage(host, bytesUpload+bytesDownload) }) } diff --git a/proxy/internal/accesslog/requestip.go b/proxy/internal/accesslog/requestip.go index f111c1322..30c483fd9 100644 --- a/proxy/internal/accesslog/requestip.go +++ b/proxy/internal/accesslog/requestip.go @@ -11,6 +11,6 @@ import ( // proxy configuration. When trustedProxies is non-empty and the direct // connection is from a trusted source, it walks X-Forwarded-For right-to-left // skipping trusted IPs. Otherwise it returns RemoteAddr directly. -func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string { +func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) netip.Addr { return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies) } diff --git a/proxy/internal/accesslog/statuswriter.go b/proxy/internal/accesslog/statuswriter.go index 43cda59f9..24f7b35e9 100644 --- a/proxy/internal/accesslog/statuswriter.go +++ b/proxy/internal/accesslog/statuswriter.go @@ -1,18 +1,39 @@ package accesslog import ( + "io" + "github.com/netbirdio/netbird/proxy/internal/responsewriter" ) -// statusWriter captures the HTTP status code from WriteHeader calls. +// statusWriter captures the HTTP status code and bytes written from responses. // It embeds responsewriter.PassthroughWriter which handles all the optional // interfaces (Hijacker, Flusher, Pusher) automatically. type statusWriter struct { *responsewriter.PassthroughWriter - status int + status int + bytesWritten int64 } func (w *statusWriter) WriteHeader(status int) { w.status = status w.PassthroughWriter.WriteHeader(status) } + +func (w *statusWriter) Write(b []byte) (int, error) { + n, err := w.PassthroughWriter.Write(b) + w.bytesWritten += int64(n) + return n, err +} + +// bodyCounter wraps an io.ReadCloser and counts bytes read from the request body. +type bodyCounter struct { + io.ReadCloser + bytesRead *int64 +} + +func (bc *bodyCounter) Read(p []byte) (int, error) { + n, err := bc.ReadCloser.Read(p) + *bc.bytesRead += int64(n) + return n, err +} diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index a663b8138..a4a220ed7 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -5,10 +5,16 @@ import ( "crypto/tls" "crypto/x509" "encoding/asn1" + "encoding/base64" "encoding/binary" + "encoding/pem" "fmt" + "math/rand/v2" "net" + "os" + "path/filepath" "slices" + "strings" "sync" "time" @@ -16,6 +22,8 @@ import ( "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" + "github.com/netbirdio/netbird/proxy/internal/certwatch" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -23,7 +31,7 @@ import ( var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2} type certificateNotifier interface { - NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error + NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error } type domainState int @@ -35,12 +43,44 @@ const ( ) type domainInfo struct { - accountID string - serviceID string + accountID types.AccountID + serviceID types.ServiceID state domainState err string } +type metricsRecorder interface { + RecordCertificateIssuance(duration time.Duration) +} + +// wildcardEntry maps a domain suffix (e.g. ".example.com") to a certwatch +// watcher that hot-reloads the corresponding wildcard certificate from disk. +type wildcardEntry struct { + suffix string // e.g. ".example.com" + pattern string // e.g. "*.example.com" + watcher *certwatch.Watcher +} + +// ManagerConfig holds the configuration values for the ACME certificate manager. +type ManagerConfig struct { + // CertDir is the directory used for caching ACME certificates. + CertDir string + // ACMEURL is the ACME directory URL (e.g. Let's Encrypt). + ACMEURL string + // EABKID and EABHMACKey are optional External Account Binding credentials + // required by some CAs (e.g. ZeroSSL). EABHMACKey is the base64 + // URL-encoded string provided by the CA. + EABKID string + EABHMACKey string + // LockMethod controls the cross-replica coordination strategy. + LockMethod CertLockMethod + // WildcardDir is an optional path to a directory containing wildcard + // certificate pairs (.crt / .key). Wildcard patterns are + // extracted from the certificates' SAN lists. Domains matching a + // wildcard are served from disk; all others go through ACME. + WildcardDir string +} + // Manager wraps autocert.Manager with domain tracking and cross-replica // coordination via a pluggable locking strategy. The locker prevents // duplicate ACME requests when multiple replicas share a certificate cache. @@ -52,33 +92,182 @@ type Manager struct { mu sync.RWMutex domains map[domain.Domain]*domainInfo + // wildcards holds all loaded wildcard certificates, keyed by suffix. + wildcards []wildcardEntry + certNotifier certificateNotifier logger *log.Logger + metrics metricsRecorder } -// NewManager creates a new ACME certificate manager. The certDir is used -// for caching certificates. The lockMethod controls cross-replica -// coordination strategy (see CertLockMethod constants). -func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { +// NewManager creates a new ACME certificate manager. +func NewManager(cfg ManagerConfig, notifier certificateNotifier, logger *log.Logger, metrics metricsRecorder) (*Manager, error) { if logger == nil { logger = log.StandardLogger() } mgr := &Manager{ - certDir: certDir, - locker: newCertLocker(lockMethod, certDir, logger), + certDir: cfg.CertDir, + locker: newCertLocker(cfg.LockMethod, cfg.CertDir, logger), domains: make(map[domain.Domain]*domainInfo), certNotifier: notifier, logger: logger, + metrics: metrics, } + + if cfg.WildcardDir != "" { + entries, err := loadWildcardDir(cfg.WildcardDir, logger) + if err != nil { + return nil, fmt.Errorf("load wildcard certificates from %q: %w", cfg.WildcardDir, err) + } + mgr.wildcards = entries + } + + var eab *acme.ExternalAccountBinding + if cfg.EABKID != "" && cfg.EABHMACKey != "" { + decodedKey, err := base64.RawURLEncoding.DecodeString(cfg.EABHMACKey) + if err != nil { + logger.Errorf("failed to decode EAB HMAC key: %v", err) + } else { + eab = &acme.ExternalAccountBinding{ + KID: cfg.EABKID, + Key: decodedKey, + } + logger.Infof("configured External Account Binding with KID: %s", cfg.EABKID) + } + } + mgr.Manager = &autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: mgr.hostPolicy, - Cache: autocert.DirCache(certDir), + Prompt: autocert.AcceptTOS, + HostPolicy: mgr.hostPolicy, + Cache: autocert.DirCache(cfg.CertDir), + ExternalAccountBinding: eab, Client: &acme.Client{ - DirectoryURL: acmeURL, + DirectoryURL: cfg.ACMEURL, }, } - return mgr + return mgr, nil +} + +// WatchWildcards starts watching all wildcard certificate files for changes. +// It blocks until ctx is cancelled. It is a no-op if no wildcards are loaded. +func (mgr *Manager) WatchWildcards(ctx context.Context) { + if len(mgr.wildcards) == 0 { + return + } + seen := make(map[*certwatch.Watcher]struct{}) + var wg sync.WaitGroup + for i := range mgr.wildcards { + w := mgr.wildcards[i].watcher + if _, ok := seen[w]; ok { + continue + } + seen[w] = struct{}{} + wg.Add(1) + go func() { + defer wg.Done() + w.Watch(ctx) + }() + } + wg.Wait() +} + +// loadWildcardDir scans dir for .crt files, pairs each with a matching .key +// file, loads them, and extracts wildcard SANs (*.example.com) to build +// the suffix lookup entries. +func loadWildcardDir(dir string, logger *log.Logger) ([]wildcardEntry, error) { + crtFiles, err := filepath.Glob(filepath.Join(dir, "*.crt")) + if err != nil { + return nil, fmt.Errorf("glob certificate files: %w", err) + } + + if len(crtFiles) == 0 { + return nil, fmt.Errorf("no .crt files found in %s", dir) + } + + var entries []wildcardEntry + + for _, crtPath := range crtFiles { + base := strings.TrimSuffix(filepath.Base(crtPath), ".crt") + keyPath := filepath.Join(dir, base+".key") + if _, err := os.Stat(keyPath); err != nil { + logger.Warnf("skipping %s: no matching key file %s", crtPath, keyPath) + continue + } + + watcher, err := certwatch.NewWatcher(crtPath, keyPath, logger) + if err != nil { + logger.Warnf("skipping %s: %v", crtPath, err) + continue + } + + leaf := watcher.Leaf() + if leaf == nil { + logger.Warnf("skipping %s: no parsed leaf certificate", crtPath) + continue + } + + for _, san := range leaf.DNSNames { + suffix, ok := parseWildcard(san) + if !ok { + continue + } + entries = append(entries, wildcardEntry{ + suffix: suffix, + pattern: san, + watcher: watcher, + }) + logger.Infof("wildcard certificate loaded: %s (from %s)", san, filepath.Base(crtPath)) + } + } + + if len(entries) == 0 { + return nil, fmt.Errorf("no wildcard SANs (*.example.com) found in certificates in %s", dir) + } + + return entries, nil +} + +// parseWildcard validates a wildcard domain pattern like "*.example.com" +// and returns the suffix ".example.com" for matching. +func parseWildcard(pattern string) (suffix string, ok bool) { + if !strings.HasPrefix(pattern, "*.") { + return "", false + } + parent := pattern[1:] // ".example.com" + if strings.Count(parent, ".") < 1 { + return "", false + } + return strings.ToLower(parent), true +} + +// findWildcardEntry returns the wildcard entry that covers host, or nil. +func (mgr *Manager) findWildcardEntry(host string) *wildcardEntry { + if len(mgr.wildcards) == 0 { + return nil + } + host = strings.ToLower(host) + for i := range mgr.wildcards { + e := &mgr.wildcards[i] + if !strings.HasSuffix(host, e.suffix) { + continue + } + // Single-level match: prefix before suffix must have no dots. + prefix := strings.TrimSuffix(host, e.suffix) + if len(prefix) > 0 && !strings.Contains(prefix, ".") { + return e + } + } + return nil +} + +// WildcardPatterns returns the wildcard patterns that are currently loaded. +func (mgr *Manager) WildcardPatterns() []string { + patterns := make([]string, len(mgr.wildcards)) + for i, e := range mgr.wildcards { + patterns[i] = e.pattern + } + slices.Sort(patterns) + return patterns } func (mgr *Manager) hostPolicy(_ context.Context, host string) error { @@ -94,8 +283,39 @@ func (mgr *Manager) hostPolicy(_ context.Context, host string) error { return nil } -// AddDomain registers a domain for ACME certificate prefetching. -func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { +// GetCertificate returns the TLS certificate for the given ClientHello. +// If the requested domain matches a loaded wildcard, the static wildcard +// certificate is returned. Otherwise, the ACME autocert manager handles +// the request. +func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if e := mgr.findWildcardEntry(hello.ServerName); e != nil { + return e.watcher.GetCertificate(hello) + } + return mgr.Manager.GetCertificate(hello) +} + +// AddDomain registers a domain for certificate management. Domains that +// match a loaded wildcard are marked ready immediately (they use the +// static wildcard certificate) and the method returns true. All other +// domains go through ACME prefetch and the method returns false. +// +// When AddDomain returns true the caller is responsible for sending any +// certificate-ready notifications after the surrounding operation (e.g. +// mapping update) has committed successfully. +func (mgr *Manager) AddDomain(d domain.Domain, accountID types.AccountID, serviceID types.ServiceID) (wildcardHit bool) { + name := d.PunycodeString() + if e := mgr.findWildcardEntry(name); e != nil { + mgr.mu.Lock() + mgr.domains[d] = &domainInfo{ + accountID: accountID, + serviceID: serviceID, + state: domainReady, + } + mgr.mu.Unlock() + mgr.logger.Debugf("domain %q matches wildcard %q, using static certificate", name, e.pattern) + return true + } + mgr.mu.Lock() mgr.domains[d] = &domainInfo{ accountID: accountID, @@ -105,13 +325,19 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { mgr.mu.Unlock() go mgr.prefetchCertificate(d) + return false } // prefetchCertificate proactively triggers certificate generation for a domain. // It acquires a distributed lock to prevent multiple replicas from issuing // duplicate ACME requests. The second replica will block until the first // finishes, then find the certificate in the cache. +// ACME and periodic disk reads race; whichever produces a valid certificate +// first wins. This handles cases where locking is unreliable and another +// replica already wrote the cert to the shared cache. func (mgr *Manager) prefetchCertificate(d domain.Domain) { + time.Sleep(time.Duration(rand.IntN(200)) * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -127,22 +353,105 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { defer unlock() } - hello := &tls.ClientHelloInfo{ - ServerName: name, - Conn: &dummyConn{ctx: ctx}, - } - - start := time.Now() - cert, err := mgr.GetCertificate(hello) - elapsed := time.Since(start) - if err != nil { - mgr.logger.Warnf("prefetch certificate for domain %q: %v", name, err) - mgr.setDomainState(d, domainFailed, err.Error()) + if cert, err := mgr.readCertFromDisk(ctx, name); err == nil { + mgr.logger.Infof("certificate for domain %q already on disk, skipping ACME", name) + mgr.recordAndNotify(ctx, d, name, cert, 0) return } - mgr.setDomainState(d, domainReady, "") + // Run ACME in a goroutine so we can race it against periodic disk reads. + // autocert uses its own internal context and cannot be cancelled externally. + type acmeResult struct { + cert *tls.Certificate + err error + } + acmeCh := make(chan acmeResult, 1) + hello := &tls.ClientHelloInfo{ServerName: name, Conn: &dummyConn{ctx: ctx}} + go func() { + cert, err := mgr.GetCertificate(hello) + acmeCh <- acmeResult{cert, err} + }() + start := time.Now() + diskTicker := time.NewTicker(5 * time.Second) + defer diskTicker.Stop() + + for { + select { + case res := <-acmeCh: + elapsed := time.Since(start) + if res.err != nil { + mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), res.err) + mgr.setDomainState(d, domainFailed, res.err.Error()) + return + } + mgr.recordAndNotify(ctx, d, name, res.cert, elapsed) + return + + case <-diskTicker.C: + cert, err := mgr.readCertFromDisk(context.Background(), name) + if err != nil { + continue + } + mgr.logger.Infof("certificate for domain %q appeared on disk after %s", name, time.Since(start).Round(time.Millisecond)) + // Drain the ACME goroutine before marking ready — autocert holds + // an internal write lock on certState while ACME is in flight. + go func() { + select { + case <-acmeCh: + default: + } + mgr.recordAndNotify(context.Background(), d, name, cert, 0) + }() + return + + case <-ctx.Done(): + mgr.logger.Warnf("prefetch certificate for domain %q timed out", name) + mgr.setDomainState(d, domainFailed, ctx.Err().Error()) + return + } + } +} + +// readCertFromDisk reads and parses a certificate directly from the autocert +// DirCache, bypassing autocert's internal certState mutex. Safe to call +// concurrently with an in-flight ACME request for the same domain. +func (mgr *Manager) readCertFromDisk(ctx context.Context, name string) (*tls.Certificate, error) { + if mgr.Cache == nil { + return nil, fmt.Errorf("no cache configured") + } + data, err := mgr.Cache.Get(ctx, name) + if err != nil { + return nil, err + } + privBlock, certsPEM := pem.Decode(data) + if privBlock == nil || !strings.Contains(privBlock.Type, "PRIVATE") { + return nil, fmt.Errorf("no private key in cache for %q", name) + } + cert, err := tls.X509KeyPair(certsPEM, pem.EncodeToMemory(privBlock)) + if err != nil { + return nil, fmt.Errorf("parse cached certificate for %q: %w", name, err) + } + if len(cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("parse leaf for %q: %w", name, err) + } + if time.Now().After(leaf.NotAfter) { + return nil, fmt.Errorf("cached certificate for %q expired at %s", name, leaf.NotAfter) + } + cert.Leaf = leaf + } + return &cert, nil +} + +// recordAndNotify records metrics, marks the domain ready, logs cert details, +// and notifies the cert notifier. +func (mgr *Manager) recordAndNotify(ctx context.Context, d domain.Domain, name string, cert *tls.Certificate, elapsed time.Duration) { + if elapsed > 0 && mgr.metrics != nil { + mgr.metrics.RecordCertificateIssuance(elapsed) + } + mgr.setDomainState(d, domainReady, "") now := time.Now() if cert != nil && cert.Leaf != nil { leaf := cert.Leaf @@ -158,11 +467,9 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { } else { mgr.logger.Infof("certificate for domain %q ready in %s", name, elapsed.Round(time.Millisecond)) } - mgr.mu.RLock() info := mgr.domains[d] mgr.mu.RUnlock() - if info != nil && mgr.certNotifier != nil { if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil { mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err) diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index 3b554e360..ceb9ca13a 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -2,16 +2,29 @@ package acme import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" ) func TestHostPolicy(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") - mgr.AddDomain("example.com", "acc1", "rp1") + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) + mgr.AddDomain("example.com", types.AccountID("acc1"), types.ServiceID("rp1")) // Wait for the background prefetch goroutine to finish so the temp dir // can be cleaned up without a race. @@ -70,7 +83,8 @@ func TestHostPolicy(t *testing.T) { } func TestDomainStates(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") @@ -80,8 +94,8 @@ func TestDomainStates(t *testing.T) { // AddDomain starts as pending, then the prefetch goroutine will fail // (no real ACME server) and transition to failed. - mgr.AddDomain("a.example.com", "acc1", "rp1") - mgr.AddDomain("b.example.com", "acc1", "rp1") + mgr.AddDomain("a.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) + mgr.AddDomain("b.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered") @@ -100,3 +114,193 @@ func TestDomainStates(t *testing.T) { assert.Contains(t, failed, "b.example.com") assert.Empty(t, mgr.ReadyDomains()) } + +func TestParseWildcard(t *testing.T) { + tests := []struct { + pattern string + wantSuffix string + wantOK bool + }{ + {"*.example.com", ".example.com", true}, + {"*.foo.example.com", ".foo.example.com", true}, + {"*.COM", ".com", true}, // single-label TLD + {"example.com", "", false}, // no wildcard prefix + {"*example.com", "", false}, // missing dot + {"**.example.com", "", false}, // double star + {"", "", false}, + } + + for _, tc := range tests { + t.Run(tc.pattern, func(t *testing.T) { + suffix, ok := parseWildcard(tc.pattern) + assert.Equal(t, tc.wantOK, ok) + if ok { + assert.Equal(t, tc.wantSuffix, suffix) + } + }) + } +} + +func TestMatchesWildcard(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + tests := []struct { + host string + match bool + }{ + {"foo.example.com", true}, + {"bar.example.com", true}, + {"FOO.Example.COM", true}, // case insensitive + {"example.com", false}, // bare parent + {"sub.foo.example.com", false}, // multi-level + {"notexample.com", false}, + {"", false}, + } + + for _, tc := range tests { + t.Run(tc.host, func(t *testing.T) { + assert.Equal(t, tc.match, mgr.findWildcardEntry(tc.host) != nil) + }) + } +} + +// generateSelfSignedCert creates a temporary self-signed certificate and key +// for testing purposes. The baseName controls the output filenames: +// .crt and .key. +func generateSelfSignedCert(t *testing.T, dir, baseName string, dnsNames ...string) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: dnsNames[0]}, + DNSNames: dnsNames, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certFile, err := os.Create(filepath.Join(dir, baseName+".crt")) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + keyFile, err := os.Create(filepath.Join(dir, baseName+".key")) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) +} + +func TestWildcardAddDomainSkipsACME(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + // Add a wildcard-matching domain — should be immediately ready. + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending") + assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) + + // Add a non-wildcard domain — should go through ACME (pending then failed). + mgr.AddDomain("other.net", types.AccountID("acc2"), types.ServiceID("svc2")) + assert.Equal(t, 2, mgr.TotalDomains()) + + // Wait for the ACME prefetch to fail. + assert.Eventually(t, func() bool { + return mgr.PendingCerts() == 0 + }, 30*time.Second, 100*time.Millisecond) + + assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) + assert.Contains(t, mgr.FailedDomains(), "other.net") +} + +func TestWildcardGetCertificate(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + + // GetCertificate for a wildcard-matching domain should return the static cert. + cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + require.NotNil(t, cert) + assert.Contains(t, cert.Leaf.DNSNames, "*.example.com") +} + +func TestMultipleWildcards(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + generateSelfSignedCert(t, wcDir, "other", "*.other.org") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns()) + + // Both wildcards should resolve. + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + mgr.AddDomain("bar.other.org", types.AccountID("acc2"), types.ServiceID("svc2")) + + assert.Equal(t, 0, mgr.PendingCerts()) + assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains()) + + // GetCertificate routes to the correct cert. + cert1, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + assert.Contains(t, cert1.Leaf.DNSNames, "*.example.com") + + cert2, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "bar.other.org"}) + require.NoError(t, err) + assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org") + + // Non-matching domain falls through to ACME. + mgr.AddDomain("custom.net", types.AccountID("acc3"), types.ServiceID("svc3")) + assert.Eventually(t, func() bool { + return mgr.PendingCerts() == 0 + }, 30*time.Second, 100*time.Millisecond) + assert.Contains(t, mgr.FailedDomains(), "custom.net") +} + +func TestWildcardDirEmpty(t *testing.T) { + wcDir := t.TempDir() + // Empty directory — no .crt files. + _, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no .crt files found") +} + +func TestWildcardDirNonWildcardCert(t *testing.T) { + wcDir := t.TempDir() + // Certificate without a wildcard SAN. + generateSelfSignedCert(t, wcDir, "plain", "plain.example.com") + + _, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no wildcard SANs") +} + +func TestNoWildcardDir(t *testing.T) { + // Empty string means no wildcard dir — pure ACME mode. + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) + assert.Empty(t, mgr.WildcardPatterns()) +} diff --git a/proxy/internal/auth/header.go b/proxy/internal/auth/header.go new file mode 100644 index 000000000..194800a49 --- /dev/null +++ b/proxy/internal/auth/header.go @@ -0,0 +1,69 @@ +package auth + +import ( + "errors" + "fmt" + "net/http" + + "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// ErrHeaderAuthFailed indicates that the header was present but the +// credential did not validate. Callers should return 401 instead of +// falling through to other auth schemes. +var ErrHeaderAuthFailed = errors.New("header authentication failed") + +// Header implements header-based authentication. The proxy checks for the +// configured header in each request and validates its value via gRPC. +type Header struct { + id types.ServiceID + accountId types.AccountID + headerName string + client authenticator +} + +// NewHeader creates a Header authentication scheme for the given header name. +func NewHeader(client authenticator, id types.ServiceID, accountId types.AccountID, headerName string) Header { + return Header{ + id: id, + accountId: accountId, + headerName: headerName, + client: client, + } +} + +// Type returns auth.MethodHeader. +func (Header) Type() auth.Method { + return auth.MethodHeader +} + +// Authenticate checks for the configured header in the request. If absent, +// returns empty (unauthenticated). If present, validates via gRPC. +func (h Header) Authenticate(r *http.Request) (string, string, error) { + value := r.Header.Get(h.headerName) + if value == "" { + return "", "", nil + } + + res, err := h.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ + Id: string(h.id), + AccountId: string(h.accountId), + Request: &proto.AuthenticateRequest_HeaderAuth{ + HeaderAuth: &proto.HeaderAuthRequest{ + HeaderValue: value, + HeaderName: h.headerName, + }, + }, + }) + if err != nil { + return "", "", fmt.Errorf("authenticate header: %w", err) + } + + if res.GetSuccess() { + return res.GetSessionToken(), "", nil + } + + return "", "", ErrHeaderAuthFailed +} diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 8a966faa3..055e4510f 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -4,9 +4,12 @@ import ( "context" "crypto/ed25519" "encoding/base64" + "errors" "fmt" + "html" "net" "net/http" + "net/netip" "net/url" "sync" "time" @@ -16,11 +19,16 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/proto" ) +// errValidationUnavailable indicates that session validation failed due to +// an infrastructure error (e.g. gRPC unavailable), not an invalid token. +var errValidationUnavailable = errors.New("session validation unavailable") + type authenticator interface { Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error) } @@ -40,12 +48,14 @@ type Scheme interface { Authenticate(*http.Request) (token string, promptData string, err error) } +// DomainConfig holds the authentication and restriction settings for a protected domain. type DomainConfig struct { Schemes []Scheme SessionPublicKey ed25519.PublicKey SessionExpiration time.Duration - AccountID string - ServiceID string + AccountID types.AccountID + ServiceID types.ServiceID + IPRestrictions *restrict.Filter } type validationResult struct { @@ -54,17 +64,18 @@ type validationResult struct { DeniedReason string } +// Middleware applies per-domain authentication and IP restriction checks. type Middleware struct { domainsMux sync.RWMutex domains map[string]DomainConfig logger *log.Logger sessionValidator SessionValidator + geo restrict.GeoResolver } -// NewMiddleware creates a new authentication middleware. -// The sessionValidator is optional; if nil, OIDC session tokens will be validated -// locally without group access checks. -func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware { +// NewMiddleware creates a new authentication middleware. The sessionValidator is +// optional; if nil, OIDC session tokens are validated locally without group access checks. +func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware { if logger == nil { logger = log.StandardLogger() } @@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl domains: make(map[string]DomainConfig), logger: logger, sessionValidator: sessionValidator, + geo: geo, } } -// Protect applies authentication middleware to the passed handler. -// For each incoming request it will be checked against the middleware's -// internal list of protected domains. -// If the Host domain in the inbound request is not present, then it will -// simply be passed through. -// However, if the Host domain is present, then the specified authentication -// schemes for that domain will be applied to the request. -// In the event that no authentication schemes are defined for the domain, -// then the request will also be simply passed through. +// Protect wraps next with per-domain authentication and IP restriction checks. +// Requests whose Host is not registered pass through unchanged. func (mw *Middleware) Protect(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { host, _, err := net.SplitHostPort(r.Host) @@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { config, exists := mw.getDomainConfig(host) mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists) - // Domains that are not configured here or have no authentication schemes applied should simply pass through. - if !exists || len(config.Schemes) == 0 { + if !exists { next.ServeHTTP(w, r) return } @@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { // Set account and service IDs in captured data for access logging. setCapturedIDs(r, config) + if !mw.checkIPRestrictions(w, r, config) { + return + } + + // Domains with no authentication schemes pass through after IP checks. + if len(config.Schemes) == 0 { + next.ServeHTTP(w, r) + return + } + if mw.handleOAuthCallbackError(w, r) { return } @@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { return } + if mw.forwardWithHeaderAuth(w, r, host, config, next) { + return + } + mw.authenticateWithSchemes(w, r, host, config) }) } @@ -124,11 +142,79 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { func setCapturedIDs(r *http.Request, config DomainConfig) { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetAccountId(types.AccountID(config.AccountID)) - cd.SetServiceId(config.ServiceID) + cd.SetAccountID(config.AccountID) + cd.SetServiceID(config.ServiceID) } } +// checkIPRestrictions validates the client IP against the domain's IP restrictions. +// Uses the resolved client IP from CapturedData (which accounts for trusted proxies) +// rather than r.RemoteAddr directly. +func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool { + if config.IPRestrictions == nil { + return true + } + + clientIP := mw.resolveClientIP(r) + if !clientIP.IsValid() { + mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr) + http.Error(w, "Forbidden", http.StatusForbidden) + return false + } + + verdict := config.IPRestrictions.Check(clientIP, mw.geo) + if verdict == restrict.Allow { + return true + } + + if verdict.IsCrowdSec() { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetMetadata("crowdsec_verdict", verdict.String()) + if config.IPRestrictions.IsObserveOnly(verdict) { + cd.SetMetadata("crowdsec_mode", "observe") + } + } + } + + if config.IPRestrictions.IsObserveOnly(verdict) { + mw.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", clientIP, r.Host, verdict) + return true + } + + reason := verdict.String() + mw.blockIPRestriction(r, reason) + http.Error(w, "Forbidden", http.StatusForbidden) + return false +} + +// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr. +func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + if ip := cd.GetClientIP(); ip.IsValid() { + return ip + } + } + + clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr) + if clientIPStr == "" { + clientIPStr = r.RemoteAddr + } + addr, err := netip.ParseAddr(clientIPStr) + if err != nil { + return netip.Addr{} + } + return addr.Unmap() +} + +// blockIPRestriction sets captured data fields for an IP-restriction block event. +func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetAuthMethod(reason) + } + mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr) +} + // handleOAuthCallbackError checks for error query parameters from an OAuth // callback and renders the access denied page if present. func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool { @@ -146,6 +232,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re errDesc := r.URL.Query().Get("error_description") if errDesc == "" { errDesc = "An error occurred during authentication" + } else { + errDesc = html.EscapeString(errDesc) } web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID) return true @@ -170,6 +258,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re return true } +// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates, +// the request is forwarded directly (no redirect), which is important for API clients. +func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool { + for _, scheme := range config.Schemes { + hdr, ok := scheme.(Header) + if !ok { + continue + } + + handled := mw.tryHeaderScheme(w, r, host, config, hdr, next) + if handled { + return true + } + } + return false +} + +func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool { + token, _, err := hdr.Authenticate(r) + if err != nil { + return mw.handleHeaderAuthError(w, r, err) + } + if token == "" { + return false + } + + result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader) + if err != nil { + setHeaderCapturedData(r.Context(), "") + status := http.StatusBadRequest + msg := "invalid session token" + if errors.Is(err, errValidationUnavailable) { + status = http.StatusBadGateway + msg = "authentication service unavailable" + } + http.Error(w, msg, status) + return true + } + + if !result.Valid { + setHeaderCapturedData(r.Context(), result.UserID) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return true + } + + setSessionCookie(w, token, config.SessionExpiration) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetUserID(result.UserID) + cd.SetAuthMethod(auth.MethodHeader.String()) + } + + next.ServeHTTP(w, r) + return true +} + +func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool { + if errors.Is(err, ErrHeaderAuthFailed) { + setHeaderCapturedData(r.Context(), "") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return true + } + mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + } + http.Error(w, "authentication service unavailable", http.StatusBadGateway) + return true +} + +func setHeaderCapturedData(ctx context.Context, userID string) { + cd := proxy.CapturedDataFromContext(ctx) + if cd == nil { + return + } + cd.SetOrigin(proxy.OriginAuth) + cd.SetAuthMethod(auth.MethodHeader.String()) + cd.SetUserID(userID) +} + // authenticateWithSchemes tries each configured auth scheme in order. // On success it sets a session cookie and redirects; on failure it renders the login page. func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) { @@ -205,6 +372,12 @@ func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Req cd.SetAuthMethod(attemptedMethod) } } + + if oidcURL, ok := methods[auth.MethodOIDC.String()]; ok && len(methods) == 1 && oidcURL != "" { + http.Redirect(w, r, oidcURL, http.StatusFound) + return + } + web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized) } @@ -217,7 +390,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re cd.SetOrigin(proxy.OriginAuth) cd.SetAuthMethod(scheme.Type().String()) } - http.Error(w, err.Error(), http.StatusBadRequest) + status := http.StatusBadRequest + msg := "invalid session token" + if errors.Is(err, errValidationUnavailable) { + status = http.StatusBadGateway + msg = "authentication service unavailable" + } + http.Error(w, msg, status) return } @@ -233,7 +412,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re return } - expiration := config.SessionExpiration + setSessionCookie(w, token, config.SessionExpiration) + + // Redirect instead of forwarding the auth POST to the backend. + // The browser will follow with a GET carrying the new session cookie. + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(result.UserID) + cd.SetAuthMethod(scheme.Type().String()) + } + redirectURL := stripSessionTokenParam(r.URL) + http.Redirect(w, r, redirectURL, http.StatusSeeOther) +} + +// setSessionCookie writes a session cookie with secure defaults. +func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) { if expiration == 0 { expiration = auth.DefaultSessionExpiry } @@ -245,16 +438,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re SameSite: http.SameSiteLaxMode, MaxAge: int(expiration.Seconds()), }) - - // Redirect instead of forwarding the auth POST to the backend. - // The browser will follow with a GET carrying the new session cookie. - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - cd.SetUserID(result.UserID) - cd.SetAuthMethod(scheme.Type().String()) - } - redirectURL := stripSessionTokenParam(r.URL) - http.Redirect(w, r, redirectURL, http.StatusSeeOther) } // wasCredentialSubmitted checks if credentials were submitted for the given auth method. @@ -275,13 +458,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool { // session JWTs. Returns an error if the key is missing or invalid. // Callers must not serve the domain if this returns an error, to avoid // exposing an unauthenticated service. -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error { +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() mw.domains[domain] = DomainConfig{ - AccountID: accountID, - ServiceID: serviceID, + AccountID: accountID, + ServiceID: serviceID, + IPRestrictions: ipRestrictions, } return nil } @@ -302,30 +486,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st SessionExpiration: expiration, AccountID: accountID, ServiceID: serviceID, + IPRestrictions: ipRestrictions, } return nil } +// RemoveDomain unregisters authentication for the given domain. func (mw *Middleware) RemoveDomain(domain string) { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() delete(mw.domains, domain) } -// validateSessionToken validates a session token, optionally checking group access via gRPC. -// For OIDC tokens with a configured validator, it calls ValidateSession to check group access. -// For other auth methods (PIN, password), it validates the JWT locally. -// Returns a validationResult with user ID and validity status, or error for invalid tokens. +// validateSessionToken validates a session token. OIDC tokens with a configured +// validator go through gRPC for group access checks; other methods validate locally. func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) { - // For OIDC with a session validator, call the gRPC service to check group access if method == auth.MethodOIDC && mw.sessionValidator != nil { resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{ Domain: host, SessionToken: token, }) if err != nil { - mw.logger.WithError(err).Error("ValidateSession gRPC call failed") - return nil, fmt.Errorf("session validation failed") + return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err) } if !resp.Valid { mw.logger.WithFields(log.Fields{ @@ -342,7 +524,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri return &validationResult{UserID: resp.UserId, Valid: true}, nil } - // For non-OIDC methods or when no validator is configured, validate JWT locally userID, _, err := auth.ValidateSessionJWT(token, host, publicKey) if err != nil { return nil, err diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 7d9ac1bd5..16d09800c 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -1,11 +1,14 @@ package auth import ( + "context" "crypto/ed25519" "crypto/rand" "encoding/base64" + "errors" "net/http" "net/http/httptest" + "net/netip" "net/url" "strings" "testing" @@ -14,10 +17,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/restrict" + "github.com/netbirdio/netbird/shared/management/proto" ) func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair { @@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler { } func TestAddDomain_ValidKey(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil) require.NoError(t, err) mw.domainsMux.RLock() @@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) { } func TestAddDomain_EmptyKey(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil) require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) { } func TestAddDomain_InvalidBase64(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil) require.Error(t, err) assert.Contains(t, err.Error(), "decode session public key") @@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) { } func TestAddDomain_WrongKeySize(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil) require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) { } func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) - err := mw.AddDomain("example.com", nil, "", time.Hour, "", "") + err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil) require.NoError(t, err, "domains with no auth schemes should not require a key") mw.domainsMux.RLock() @@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { } func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp1 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil)) mw.domainsMux.RLock() config := mw.domains["example.com"] @@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { } func TestRemoveDomain(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) mw.RemoveDomain("example.com") @@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) { } func TestProtect_UnknownDomainPassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) handler := mw.Protect(newPassthroughHandler()) req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil) @@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) { } func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) - require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "")) + mw := NewMiddleware(log.StandardLogger(), nil, nil) + require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)) handler := mw.Protect(newPassthroughHandler()) @@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { } func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { } func TestProtect_HostWithPortIsMatched(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) { } func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) require.NoError(t, err) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cd := proxy.CapturedDataFromContext(r.Context()) require.NotNil(t, cd) @@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { } func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) // Sign a token that expired 1 second ago. token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second) @@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { } func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) // Token signed for a different domain audience. token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour) @@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { } func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp1 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil)) // Token signed with a different private key. token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) @@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { } func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour) @@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { } func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{ @@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) handler := mw.Protect(newPassthroughHandler()) @@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { } func TestProtect_MultipleSchemes(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour) @@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { return "", "password", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { } func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) // Return a garbage token that won't validate. @@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { return "invalid-jwt-token", "", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) handler := mw.Protect(newPassthroughHandler()) @@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { } func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) // 32 random bytes that happen to be valid base64 and correct size // but are actually a valid ed25519 public key length-wise. @@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { key := base64.StdEncoding.EncodeToString(randomBytes) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "") + err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil) require.NoError(t, err, "any 32-byte key should be accepted at registration time") } func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) // Attempt to overwrite with an invalid key. - err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil) require.Error(t, err) // The original valid config should still be intact. @@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { } func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) // Scheme that always fails authentication (returns empty token) @@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) // Submit wrong PIN - should capture auth method @@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { } func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{ @@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { return "", "password", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) // Submit wrong password - should capture auth method @@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { } func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{ @@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) // No credentials submitted - should not capture auth method @@ -658,3 +664,389 @@ func TestWasCredentialSubmitted(t *testing.T) { }) } } + +func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + remoteAddr string + wantCode int + }{ + {"unparsable address denies", "not-an-ip:1234", http.StatusForbidden}, + {"empty address denies", "", http.StatusForbidden}, + {"allowed address passes", "10.1.2.3:5678", http.StatusOK}, + {"denied address blocked", "192.168.1.1:5678", http.StatusForbidden}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = tt.remoteAddr + req.Host = "example.com" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, tt.wantCode, rr.Code) + }) + } +} + +func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) { + // When CapturedData is set (by the access log middleware, which resolves + // trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr. + mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}})) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // RemoteAddr is a trusted proxy, but CapturedData has the real client IP. + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = "10.0.0.1:5000" + req.Host = "example.com" + + cd := proxy.NewCapturedData("") + cd.SetClientIP(netip.MustParseAddr("203.0.113.50")) + ctx := proxy.WithCapturedData(req.Context(), cd) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)") + + // Same request but CapturedData has a blocked IP. + req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req2.RemoteAddr = "203.0.113.50:5000" + req2.Host = "example.com" + + cd2 := proxy.NewCapturedData("") + cd2.SetClientIP(netip.MustParseAddr("10.0.0.1")) + ctx2 := proxy.WithCapturedData(req2.Context(), cd2) + req2 = req2.WithContext(ctx2) + + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)") +} + +func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) { + // Geo is nil, country restrictions are configured: must deny (fail-close). + mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}})) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = "1.2.3.4:5678" + req.Host = "example.com" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny") +} + +func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + oidcURL := "https://idp.example.com/authorize?client_id=abc" + scheme := &stubScheme{ + method: auth.MethodOIDC, + authFn: func(_ *http.Request) (string, string, error) { + return "", oidcURL, nil + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code, "should redirect directly to IdP") + assert.Equal(t, oidcURL, rec.Header().Get("Location")) +} + +func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + oidcScheme := &stubScheme{ + method: auth.MethodOIDC, + authFn: func(_ *http.Request) (string, string, error) { + return "", "https://idp.example.com/authorize", nil + }, + } + pinScheme := &stubScheme{ + method: auth.MethodPIN, + authFn: func(_ *http.Request) (string, string, error) { + return "", "pin", nil + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, "should show login page when multiple methods exist") +} + +// mockAuthenticator is a minimal mock for the authenticator gRPC interface +// used by the Header scheme. +type mockAuthenticator struct { + fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) +} + +func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) { + return m.fn(ctx, in) +} + +// newHeaderSchemeWithToken creates a Header scheme backed by a mock that +// returns a signed session token when the expected header value is provided. +func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header { + t.Helper() + token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour) + require.NoError(t, err) + + mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + ha := req.GetHeaderAuth() + if ha != nil && ha.GetHeaderValue() == expectedValue { + return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil + } + return &proto.AuthenticateResponse{Success: false}, nil + }} + return NewHeader(mock, "svc1", "acc1", headerName) +} + +func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + var backendCalled bool + capturedData := proxy.NewCapturedData("") + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil) + req.Header.Set("X-API-Key", "secret-key") + req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)") + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "ok", rec.Body.String()) + + // Session cookie should be set. + var sessionCookie *http.Cookie + for _, c := range rec.Result().Cookies() { + if c.Name == auth.SessionCookieName { + sessionCookie = c + break + } + } + require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth") + assert.True(t, sessionCookie.HttpOnly) + assert.True(t, sessionCookie.Secure) + + assert.Equal(t, "header-user", capturedData.GetUserID()) + assert.Equal(t, "header", capturedData.GetAuthMethod()) +} + +func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") + // Also add a PIN scheme so we can verify fallthrough behavior. + pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + // No X-API-Key header: should fall through to PIN login page (401). + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page") +} + +func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + return &proto.AuthenticateResponse{Success: false}, nil + }} + hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + capturedData := proxy.NewCapturedData("") + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("X-API-Key", "wrong-key") + req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.Equal(t, "header", capturedData.GetAuthMethod()) +} + +func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + return nil, errors.New("gRPC unavailable") + }} + hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("X-API-Key", "some-key") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadGateway, rec.Code) +} + +func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request with header auth. + req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req1.Header.Set("X-API-Key", "secret-key") + req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData(""))) + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + + // Extract session cookie. + var sessionCookie *http.Cookie + for _, c := range rec1.Result().Cookies() { + if c.Name == auth.SessionCookieName { + sessionCookie = c + break + } + } + require.NotNil(t, sessionCookie) + + // Second request with only the session cookie (no header). + capturedData2 := proxy.NewCapturedData("") + req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil) + req2.AddCookie(sessionCookie) + req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2)) + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + + assert.Equal(t, http.StatusOK, rec2.Code) + assert.Equal(t, "header-user", capturedData2.GetUserID()) + assert.Equal(t, "header", capturedData2.GetAuthMethod()) +} + +// TestProtect_HeaderAuth_MultipleValuesSameHeader verifies that the proxy +// correctly handles multiple valid credentials for the same header name. +// In production, the mgmt gRPC authenticateHeader iterates all configured +// header auths and accepts if any hash matches (OR semantics). The proxy +// creates one Header scheme per entry, but a single gRPC call checks all. +func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + // Mock simulates mgmt behavior: accepts either token-a or token-b. + accepted := map[string]bool{"Bearer token-a": true, "Bearer token-b": true} + mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + ha := req.GetHeaderAuth() + if ha != nil && accepted[ha.GetHeaderValue()] { + token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour) + require.NoError(t, err) + return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil + } + return &proto.AuthenticateResponse{Success: false}, nil + }} + + // Single Header scheme (as if one entry existed), but the mock checks both values. + hdr := NewHeader(mock, "svc1", "acc1", "Authorization") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + var backendCalled bool + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + })) + + t.Run("first value accepted", func(t *testing.T) { + backendCalled = false + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer token-a") + req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData(""))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, backendCalled, "first token should be accepted") + }) + + t.Run("second value accepted", func(t *testing.T) { + backendCalled = false + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer token-b") + req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData(""))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, backendCalled, "second token should be accepted") + }) + + t.Run("unknown value rejected", func(t *testing.T) { + backendCalled = false + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer token-c") + req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData(""))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.False(t, backendCalled, "unknown token should be rejected") + }) +} diff --git a/proxy/internal/auth/oidc.go b/proxy/internal/auth/oidc.go index bf178d432..a60e6437a 100644 --- a/proxy/internal/auth/oidc.go +++ b/proxy/internal/auth/oidc.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -17,14 +18,14 @@ type urlGenerator interface { } type OIDC struct { - id string - accountId string + id types.ServiceID + accountId types.AccountID forwardedProto string client urlGenerator } // NewOIDC creates a new OIDC authentication scheme -func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC { +func NewOIDC(client urlGenerator, id types.ServiceID, accountId types.AccountID, forwardedProto string) OIDC { return OIDC{ id: id, accountId: accountId, @@ -53,8 +54,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string, error) { } res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{ - Id: o.id, - AccountId: o.accountId, + Id: string(o.id), + AccountId: string(o.accountId), RedirectUrl: redirectURL.String(), }) if err != nil { diff --git a/proxy/internal/auth/password.go b/proxy/internal/auth/password.go index 208423465..6a7eda3e1 100644 --- a/proxy/internal/auth/password.go +++ b/proxy/internal/auth/password.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const passwordFormId = "password" type Password struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPassword(client authenticator, id, accountId string) Password { +func NewPassword(client authenticator, id types.ServiceID, accountId types.AccountID) Password { return Password{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Password) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Password{ Password: &proto.PasswordRequest{ Password: password, diff --git a/proxy/internal/auth/pin.go b/proxy/internal/auth/pin.go index c1eb56071..4d08f3dc6 100644 --- a/proxy/internal/auth/pin.go +++ b/proxy/internal/auth/pin.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const pinFormId = "pin" type Pin struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPin(client authenticator, id, accountId string) Pin { +func NewPin(client authenticator, id types.ServiceID, accountId types.AccountID) Pin { return Pin{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Pin) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Pin{ Pin: &proto.PinRequest{ Pin: pin, diff --git a/proxy/internal/certwatch/watcher.go b/proxy/internal/certwatch/watcher.go index 78ad1ab7c..6366a53c6 100644 --- a/proxy/internal/certwatch/watcher.go +++ b/proxy/internal/certwatch/watcher.go @@ -67,6 +67,13 @@ func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, erro return w.cert, nil } +// Leaf returns the parsed leaf certificate, or nil if not yet loaded. +func (w *Watcher) Leaf() *x509.Certificate { + w.mu.RLock() + defer w.mu.RUnlock() + return w.leaf +} + // Watch starts watching for certificate file changes. It blocks until // ctx is cancelled. It uses fsnotify for immediate detection and falls // back to polling if fsnotify is unavailable (e.g. on NFS). diff --git a/proxy/internal/conntrack/conn.go b/proxy/internal/conntrack/conn.go index 97055d992..8446d638f 100644 --- a/proxy/internal/conntrack/conn.go +++ b/proxy/internal/conntrack/conn.go @@ -10,10 +10,11 @@ import ( type trackedConn struct { net.Conn tracker *HijackTracker + host string } func (c *trackedConn) Close() error { - c.tracker.conns.Delete(c) + c.tracker.remove(c) return c.Conn.Close() } @@ -22,6 +23,7 @@ func (c *trackedConn) Close() error { type trackingWriter struct { http.ResponseWriter tracker *HijackTracker + host string } func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if err != nil { return nil, nil, err } - tc := &trackedConn{Conn: conn, tracker: w.tracker} - w.tracker.conns.Store(tc, struct{}{}) + tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host} + w.tracker.add(tc) return tc, buf, nil } diff --git a/proxy/internal/conntrack/hijacked.go b/proxy/internal/conntrack/hijacked.go index d76cebc08..911f93f3d 100644 --- a/proxy/internal/conntrack/hijacked.go +++ b/proxy/internal/conntrack/hijacked.go @@ -1,7 +1,6 @@ package conntrack import ( - "net" "net/http" "sync" ) @@ -10,10 +9,14 @@ import ( // upgrades). http.Server.Shutdown does not close hijacked connections, so // they must be tracked and closed explicitly during graceful shutdown. // +// Connections are indexed by the request Host so they can be closed +// per-domain when a service mapping is removed. +// // Use Middleware as the outermost HTTP middleware to ensure hijacked // connections are tracked and automatically deregistered when closed. type HijackTracker struct { - conns sync.Map // net.Conn → struct{} + mu sync.Mutex + conns map[*trackedConn]struct{} } // Middleware returns an HTTP middleware that wraps the ResponseWriter so that @@ -21,21 +24,73 @@ type HijackTracker struct { // tracker when closed. This should be the outermost middleware in the chain. func (t *HijackTracker) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r) + next.ServeHTTP(&trackingWriter{ + ResponseWriter: w, + tracker: t, + host: hostOnly(r.Host), + }, r) }) } -// CloseAll closes all tracked hijacked connections and returns the number -// of connections that were closed. +// CloseAll closes all tracked hijacked connections and returns the count. func (t *HijackTracker) CloseAll() int { - var count int - t.conns.Range(func(key, _ any) bool { - if conn, ok := key.(net.Conn); ok { - _ = conn.Close() - count++ - } - t.conns.Delete(key) - return true - }) - return count + t.mu.Lock() + conns := t.conns + t.conns = nil + t.mu.Unlock() + + for tc := range conns { + _ = tc.Conn.Close() + } + return len(conns) +} + +// CloseByHost closes all tracked hijacked connections for the given host +// and returns the number of connections closed. +func (t *HijackTracker) CloseByHost(host string) int { + host = hostOnly(host) + t.mu.Lock() + var toClose []*trackedConn + for tc := range t.conns { + if tc.host == host { + toClose = append(toClose, tc) + } + } + for _, tc := range toClose { + delete(t.conns, tc) + } + t.mu.Unlock() + + for _, tc := range toClose { + _ = tc.Conn.Close() + } + return len(toClose) +} + +func (t *HijackTracker) add(tc *trackedConn) { + t.mu.Lock() + if t.conns == nil { + t.conns = make(map[*trackedConn]struct{}) + } + t.conns[tc] = struct{}{} + t.mu.Unlock() +} + +func (t *HijackTracker) remove(tc *trackedConn) { + t.mu.Lock() + delete(t.conns, tc) + t.mu.Unlock() +} + +// hostOnly strips the port from a host:port string. +func hostOnly(hostport string) string { + for i := len(hostport) - 1; i >= 0; i-- { + if hostport[i] == ':' { + return hostport[:i] + } + if hostport[i] < '0' || hostport[i] > '9' { + return hostport + } + } + return hostport } diff --git a/proxy/internal/conntrack/hijacked_test.go b/proxy/internal/conntrack/hijacked_test.go new file mode 100644 index 000000000..9ceefff78 --- /dev/null +++ b/proxy/internal/conntrack/hijacked_test.go @@ -0,0 +1,142 @@ +package conntrack + +import ( + "bufio" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing. +type fakeHijackWriter struct { + http.ResponseWriter + conn net.Conn +} + +func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +} + +func TestCloseByHost(t *testing.T) { + var tracker HijackTracker + + // Simulate hijacking two connections for different hosts. + connA1, connA2 := net.Pipe() + defer connA2.Close() + connB1, connB2 := net.Pipe() + defer connB2.Close() + + twA := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "a.example.com", + } + twB := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "b.example.com", + } + + // Use fakeHijackWriter to provide the Hijack method. + twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1} + twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1} + + _, _, err := twA.Hijack() + require.NoError(t, err) + _, _, err = twB.Hijack() + require.NoError(t, err) + + tracker.mu.Lock() + assert.Equal(t, 2, len(tracker.conns), "should track 2 connections") + tracker.mu.Unlock() + + // Close only host A. + n := tracker.CloseByHost("a.example.com") + assert.Equal(t, 1, n, "should close 1 connection for host A") + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection") + tracker.mu.Unlock() + + // Verify host A's conn is actually closed. + buf := make([]byte, 1) + _, err = connA2.Read(buf) + assert.Error(t, err, "host A pipe should be closed") + + // Host B should still be alive. + go func() { _, _ = connB1.Write([]byte("x")) }() + + // Close all remaining. + n = tracker.CloseAll() + assert.Equal(t, 1, n, "should close remaining 1 connection") + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll") + tracker.mu.Unlock() +} + +func TestCloseAll(t *testing.T) { + var tracker HijackTracker + + for range 5 { + c1, c2 := net.Pipe() + defer c2.Close() + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"} + tracker.add(tc) + } + + tracker.mu.Lock() + assert.Equal(t, 5, len(tracker.conns)) + tracker.mu.Unlock() + + n := tracker.CloseAll() + assert.Equal(t, 5, n) + + // Double CloseAll is safe. + n = tracker.CloseAll() + assert.Equal(t, 0, n) +} + +func TestTrackedConn_AutoDeregister(t *testing.T) { + var tracker HijackTracker + + c1, c2 := net.Pipe() + defer c2.Close() + + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"} + tracker.add(tc) + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns)) + tracker.mu.Unlock() + + // Close the tracked conn: should auto-deregister. + require.NoError(t, tc.Close()) + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close") + tracker.mu.Unlock() +} + +func TestHostOnly(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com:443", "example.com"}, + {"example.com", "example.com"}, + {"127.0.0.1:8080", "127.0.0.1"}, + {"[::1]:443", "[::1]"}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.want, hostOnly(tt.input)) + }) + } +} diff --git a/proxy/internal/crowdsec/bouncer.go b/proxy/internal/crowdsec/bouncer.go new file mode 100644 index 000000000..06a452520 --- /dev/null +++ b/proxy/internal/crowdsec/bouncer.go @@ -0,0 +1,251 @@ +// Package crowdsec provides a CrowdSec stream bouncer that maintains a local +// decision cache for IP reputation checks. +package crowdsec + +import ( + "context" + "errors" + "net/netip" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/models" + csbouncer "github.com/crowdsecurity/go-cs-bouncer" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/restrict" +) + +// Bouncer wraps a CrowdSec StreamBouncer, maintaining a local cache of +// active decisions for fast IP lookups. It implements restrict.CrowdSecChecker. +type Bouncer struct { + mu sync.RWMutex + ips map[netip.Addr]*restrict.CrowdSecDecision + prefixes map[netip.Prefix]*restrict.CrowdSecDecision + ready atomic.Bool + + apiURL string + apiKey string + tickerInterval time.Duration + logger *log.Entry + + // lifeMu protects cancel and done from concurrent Start/Stop calls. + lifeMu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +// compile-time check +var _ restrict.CrowdSecChecker = (*Bouncer)(nil) + +// NewBouncer creates a bouncer but does not start the stream. +func NewBouncer(apiURL, apiKey string, logger *log.Entry) *Bouncer { + return &Bouncer{ + apiURL: apiURL, + apiKey: apiKey, + logger: logger, + ips: make(map[netip.Addr]*restrict.CrowdSecDecision), + prefixes: make(map[netip.Prefix]*restrict.CrowdSecDecision), + } +} + +// Start launches the background goroutine that streams decisions from the +// CrowdSec LAPI. The stream runs until Stop is called or ctx is cancelled. +func (b *Bouncer) Start(ctx context.Context) error { + interval := b.tickerInterval + if interval == 0 { + interval = 10 * time.Second + } + stream := &csbouncer.StreamBouncer{ + APIKey: b.apiKey, + APIUrl: b.apiURL, + TickerInterval: interval.String(), + UserAgent: "netbird-proxy/1.0", + Scopes: []string{"ip", "range"}, + RetryInitialConnect: true, + } + + b.logger.Infof("connecting to CrowdSec LAPI at %s", b.apiURL) + + if err := stream.Init(); err != nil { + return err + } + + // Reset state from any previous run. + b.mu.Lock() + b.ips = make(map[netip.Addr]*restrict.CrowdSecDecision) + b.prefixes = make(map[netip.Prefix]*restrict.CrowdSecDecision) + b.mu.Unlock() + b.ready.Store(false) + + ctx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + + b.lifeMu.Lock() + if b.cancel != nil { + b.lifeMu.Unlock() + cancel() + return errors.New("bouncer already started") + } + b.cancel = cancel + b.done = done + b.lifeMu.Unlock() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + if err := stream.Run(ctx); err != nil && ctx.Err() == nil { + b.logger.Errorf("CrowdSec stream ended: %v", err) + } + }() + + go func() { + defer wg.Done() + b.consumeStream(ctx, stream) + }() + + go func() { + wg.Wait() + close(done) + }() + + return nil +} + +// Stop cancels the stream and waits for all goroutines to finish. +func (b *Bouncer) Stop() { + b.lifeMu.Lock() + cancel := b.cancel + done := b.done + b.cancel = nil + b.lifeMu.Unlock() + + if cancel != nil { + cancel() + <-done + } +} + +// Ready returns true after the first batch of decisions has been processed. +func (b *Bouncer) Ready() bool { + return b.ready.Load() +} + +// CheckIP looks up addr in the local decision cache. Returns nil if no +// active decision exists for the address. +// +// Prefix lookups are O(1): instead of scanning all stored prefixes, we +// probe the map for every possible containing prefix of the address +// (at most 33 for IPv4, 129 for IPv6). +func (b *Bouncer) CheckIP(addr netip.Addr) *restrict.CrowdSecDecision { + addr = addr.Unmap() + + b.mu.RLock() + defer b.mu.RUnlock() + + if d, ok := b.ips[addr]; ok { + return d + } + + maxBits := 32 + if addr.Is6() { + maxBits = 128 + } + // Walk from most-specific to least-specific prefix so the narrowest + // matching decision wins when ranges overlap. + for bits := maxBits; bits >= 0; bits-- { + prefix := netip.PrefixFrom(addr, bits).Masked() + if d, ok := b.prefixes[prefix]; ok { + return d + } + } + + return nil +} + +func (b *Bouncer) consumeStream(ctx context.Context, stream *csbouncer.StreamBouncer) { + first := true + for { + select { + case <-ctx.Done(): + return + case resp, ok := <-stream.Stream: + if !ok { + return + } + b.mu.Lock() + b.applyDeleted(resp.Deleted) + b.applyNew(resp.New) + b.mu.Unlock() + + if first { + b.ready.Store(true) + b.logger.Info("CrowdSec bouncer synced initial decisions") + first = false + } + } + } +} + +func (b *Bouncer) applyDeleted(decisions []*models.Decision) { + for _, d := range decisions { + if d.Value == nil || d.Scope == nil { + continue + } + value := *d.Value + + if strings.ToLower(*d.Scope) == "range" || strings.Contains(value, "/") { + prefix, err := netip.ParsePrefix(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec range deletion %q: %v", value, err) + continue + } + prefix = normalizePrefix(prefix) + delete(b.prefixes, prefix) + } else { + addr, err := netip.ParseAddr(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec IP deletion %q: %v", value, err) + continue + } + delete(b.ips, addr.Unmap()) + } + } +} + +func (b *Bouncer) applyNew(decisions []*models.Decision) { + for _, d := range decisions { + if d.Value == nil || d.Type == nil || d.Scope == nil { + continue + } + dec := &restrict.CrowdSecDecision{Type: restrict.DecisionType(*d.Type)} + value := *d.Value + + if strings.ToLower(*d.Scope) == "range" || strings.Contains(value, "/") { + prefix, err := netip.ParsePrefix(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec range %q: %v", value, err) + continue + } + prefix = normalizePrefix(prefix) + b.prefixes[prefix] = dec + } else { + addr, err := netip.ParseAddr(value) + if err != nil { + b.logger.Debugf("skip unparsable CrowdSec IP %q: %v", value, err) + continue + } + b.ips[addr.Unmap()] = dec + } + } +} + +// normalizePrefix unmaps v4-mapped-v6 addresses and zeros host bits so +// the prefix is a valid map key that matches CheckIP's probe logic. +func normalizePrefix(p netip.Prefix) netip.Prefix { + return netip.PrefixFrom(p.Addr().Unmap(), p.Bits()).Masked() +} diff --git a/proxy/internal/crowdsec/bouncer_test.go b/proxy/internal/crowdsec/bouncer_test.go new file mode 100644 index 000000000..3bd8aa068 --- /dev/null +++ b/proxy/internal/crowdsec/bouncer_test.go @@ -0,0 +1,337 @@ +package crowdsec + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "sync" + "testing" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/models" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/restrict" +) + +func TestBouncer_CheckIP_Empty(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("1.2.3.4"))) +} + +func TestBouncer_CheckIP_ExactMatch(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("10.0.0.1")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionBan, d.Type) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("10.0.0.2"))) +} + +func TestBouncer_CheckIP_PrefixMatch(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.prefixes[netip.MustParsePrefix("192.168.1.0/24")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("192.168.1.100")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionBan, d.Type) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("192.168.2.1"))) +} + +func TestBouncer_CheckIP_UnmapsV4InV6(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("::ffff:10.0.0.1")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionBan, d.Type) +} + +func TestBouncer_Ready(t *testing.T) { + b := newTestBouncer() + assert.False(t, b.Ready()) + + b.ready.Store(true) + assert.True(t, b.Ready()) +} + +func TestBouncer_CheckIP_ExactBeforePrefix(t *testing.T) { + b := newTestBouncer() + b.ready.Store(true) + b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionCaptcha} + b.prefixes[netip.MustParsePrefix("10.0.0.0/8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + d := b.CheckIP(netip.MustParseAddr("10.0.0.1")) + require.NotNil(t, d) + assert.Equal(t, restrict.DecisionCaptcha, d.Type) + + d2 := b.CheckIP(netip.MustParseAddr("10.0.0.2")) + require.NotNil(t, d2) + assert.Equal(t, restrict.DecisionBan, d2.Type) +} + +func TestBouncer_ApplyNew_IP(t *testing.T) { + b := newTestBouncer() + + b.applyNew(makeDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "ban", scenario: "test/brute"}, + decision{scope: "ip", value: "5.6.7.8", dtype: "captcha", scenario: "test/crawl"}, + )) + + require.Len(t, b.ips, 2) + assert.Equal(t, restrict.DecisionBan, b.ips[netip.MustParseAddr("1.2.3.4")].Type) + assert.Equal(t, restrict.DecisionCaptcha, b.ips[netip.MustParseAddr("5.6.7.8")].Type) +} + +func TestBouncer_ApplyNew_Range(t *testing.T) { + b := newTestBouncer() + + b.applyNew(makeDecisions( + decision{scope: "range", value: "10.0.0.0/8", dtype: "ban"}, + )) + + require.Len(t, b.prefixes, 1) + assert.NotNil(t, b.prefixes[netip.MustParsePrefix("10.0.0.0/8")]) +} + +func TestBouncer_ApplyDeleted_IP(t *testing.T) { + b := newTestBouncer() + b.ips[netip.MustParseAddr("1.2.3.4")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + b.ips[netip.MustParseAddr("5.6.7.8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + b.applyDeleted(makeDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "ban"}, + )) + + assert.Len(t, b.ips, 1) + assert.Nil(t, b.ips[netip.MustParseAddr("1.2.3.4")]) + assert.NotNil(t, b.ips[netip.MustParseAddr("5.6.7.8")]) +} + +func TestBouncer_ApplyDeleted_Range(t *testing.T) { + b := newTestBouncer() + b.prefixes[netip.MustParsePrefix("10.0.0.0/8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + b.prefixes[netip.MustParsePrefix("192.168.0.0/16")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + b.applyDeleted(makeDecisions( + decision{scope: "range", value: "10.0.0.0/8", dtype: "ban"}, + )) + + require.Len(t, b.prefixes, 1) + assert.NotNil(t, b.prefixes[netip.MustParsePrefix("192.168.0.0/16")]) +} + +func TestBouncer_ApplyNew_OverwritesExisting(t *testing.T) { + b := newTestBouncer() + b.ips[netip.MustParseAddr("1.2.3.4")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan} + + b.applyNew(makeDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "captcha"}, + )) + + assert.Equal(t, restrict.DecisionCaptcha, b.ips[netip.MustParseAddr("1.2.3.4")].Type) +} + +func TestBouncer_ApplyNew_SkipsInvalid(t *testing.T) { + b := newTestBouncer() + + b.applyNew(makeDecisions( + decision{scope: "ip", value: "not-an-ip", dtype: "ban"}, + decision{scope: "range", value: "also-not-valid", dtype: "ban"}, + )) + + assert.Empty(t, b.ips) + assert.Empty(t, b.prefixes) +} + +// TestBouncer_StreamIntegration tests the full flow: fake LAPI → StreamBouncer → Bouncer cache → CheckIP. +func TestBouncer_StreamIntegration(t *testing.T) { + lapi := newFakeLAPI() + ts := httptest.NewServer(lapi) + defer ts.Close() + + // Seed the LAPI with initial decisions. + lapi.setDecisions( + decision{scope: "ip", value: "1.2.3.4", dtype: "ban", scenario: "crowdsecurity/ssh-bf"}, + decision{scope: "range", value: "10.0.0.0/8", dtype: "ban", scenario: "crowdsecurity/http-probing"}, + decision{scope: "ip", value: "5.5.5.5", dtype: "captcha", scenario: "crowdsecurity/http-crawl"}, + ) + + b := NewBouncer(ts.URL, "test-key", log.NewEntry(log.StandardLogger())) + b.tickerInterval = 200 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + require.NoError(t, b.Start(ctx)) + defer b.Stop() + + // Wait for initial sync. + require.Eventually(t, b.Ready, 5*time.Second, 50*time.Millisecond, "bouncer should become ready") + + // Verify decisions are cached. + d := b.CheckIP(netip.MustParseAddr("1.2.3.4")) + require.NotNil(t, d, "1.2.3.4 should be banned") + assert.Equal(t, restrict.DecisionBan, d.Type) + + d2 := b.CheckIP(netip.MustParseAddr("10.1.2.3")) + require.NotNil(t, d2, "10.1.2.3 should match range ban") + assert.Equal(t, restrict.DecisionBan, d2.Type) + + d3 := b.CheckIP(netip.MustParseAddr("5.5.5.5")) + require.NotNil(t, d3, "5.5.5.5 should have captcha") + assert.Equal(t, restrict.DecisionCaptcha, d3.Type) + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("9.9.9.9")), "unknown IP should be nil") + + // Simulate a delta update: delete one IP, add a new one. + lapi.setDelta( + []decision{{scope: "ip", value: "1.2.3.4", dtype: "ban"}}, + []decision{{scope: "ip", value: "2.3.4.5", dtype: "throttle", scenario: "crowdsecurity/http-flood"}}, + ) + + // Wait for the delta to be picked up. + require.Eventually(t, func() bool { + return b.CheckIP(netip.MustParseAddr("2.3.4.5")) != nil + }, 5*time.Second, 50*time.Millisecond, "new decision should appear") + + assert.Nil(t, b.CheckIP(netip.MustParseAddr("1.2.3.4")), "deleted decision should be gone") + + d4 := b.CheckIP(netip.MustParseAddr("2.3.4.5")) + require.NotNil(t, d4) + assert.Equal(t, restrict.DecisionThrottle, d4.Type) + + // Range ban should still be active. + assert.NotNil(t, b.CheckIP(netip.MustParseAddr("10.99.99.99"))) +} + +// Helpers + +func newTestBouncer() *Bouncer { + return &Bouncer{ + ips: make(map[netip.Addr]*restrict.CrowdSecDecision), + prefixes: make(map[netip.Prefix]*restrict.CrowdSecDecision), + logger: log.NewEntry(log.StandardLogger()), + } +} + +type decision struct { + scope string + value string + dtype string + scenario string +} + +func makeDecisions(decs ...decision) []*models.Decision { + out := make([]*models.Decision, len(decs)) + for i, d := range decs { + out[i] = &models.Decision{ + Scope: strPtr(d.scope), + Value: strPtr(d.value), + Type: strPtr(d.dtype), + Scenario: strPtr(d.scenario), + Duration: strPtr("1h"), + Origin: strPtr("cscli"), + } + } + return out +} + +func strPtr(s string) *string { return &s } + +// fakeLAPI is a minimal fake CrowdSec LAPI that serves /v1/decisions/stream. +type fakeLAPI struct { + mu sync.Mutex + initial []decision + newDelta []decision + delDelta []decision + served bool // true after the initial snapshot has been served +} + +func newFakeLAPI() *fakeLAPI { + return &fakeLAPI{} +} + +func (f *fakeLAPI) setDecisions(decs ...decision) { + f.mu.Lock() + defer f.mu.Unlock() + f.initial = decs + f.served = false +} + +func (f *fakeLAPI) setDelta(deleted, added []decision) { + f.mu.Lock() + defer f.mu.Unlock() + f.delDelta = deleted + f.newDelta = added +} + +func (f *fakeLAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/decisions/stream" { + http.NotFound(w, r) + return + } + + f.mu.Lock() + defer f.mu.Unlock() + + resp := streamResponse{} + + if !f.served { + for _, d := range f.initial { + resp.New = append(resp.New, toLAPIDecision(d)) + } + f.served = true + } else { + for _, d := range f.delDelta { + resp.Deleted = append(resp.Deleted, toLAPIDecision(d)) + } + for _, d := range f.newDelta { + resp.New = append(resp.New, toLAPIDecision(d)) + } + // Clear delta after serving once. + f.delDelta = nil + f.newDelta = nil + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) //nolint:errcheck +} + +// streamResponse mirrors the CrowdSec LAPI /v1/decisions/stream JSON structure. +type streamResponse struct { + New []*lapiDecision `json:"new"` + Deleted []*lapiDecision `json:"deleted"` +} + +type lapiDecision struct { + Duration *string `json:"duration"` + Origin *string `json:"origin"` + Scenario *string `json:"scenario"` + Scope *string `json:"scope"` + Type *string `json:"type"` + Value *string `json:"value"` +} + +func toLAPIDecision(d decision) *lapiDecision { + return &lapiDecision{ + Duration: strPtr("1h"), + Origin: strPtr("cscli"), + Scenario: strPtr(d.scenario), + Scope: strPtr(d.scope), + Type: strPtr(d.dtype), + Value: strPtr(d.value), + } +} diff --git a/proxy/internal/crowdsec/registry.go b/proxy/internal/crowdsec/registry.go new file mode 100644 index 000000000..652fb6f9f --- /dev/null +++ b/proxy/internal/crowdsec/registry.go @@ -0,0 +1,103 @@ +package crowdsec + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// Registry manages a single shared Bouncer instance with reference counting. +// The bouncer starts when the first service acquires it and stops when the +// last service releases it. +type Registry struct { + mu sync.Mutex + bouncer *Bouncer + refs map[types.ServiceID]struct{} + apiURL string + apiKey string + logger *log.Entry + cancel context.CancelFunc +} + +// NewRegistry creates a registry. The bouncer is not started until Acquire is called. +func NewRegistry(apiURL, apiKey string, logger *log.Entry) *Registry { + return &Registry{ + apiURL: apiURL, + apiKey: apiKey, + logger: logger, + refs: make(map[types.ServiceID]struct{}), + } +} + +// Available returns true when the LAPI URL and API key are configured. +func (r *Registry) Available() bool { + return r.apiURL != "" && r.apiKey != "" +} + +// Acquire registers svcID as a consumer and starts the bouncer if this is the +// first consumer. Returns the shared Bouncer (which implements the restrict +// package's CrowdSecChecker interface). Returns nil if not Available. +func (r *Registry) Acquire(svcID types.ServiceID) *Bouncer { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.Available() { + return nil + } + + if _, exists := r.refs[svcID]; exists { + return r.bouncer + } + + if r.bouncer == nil { + r.startLocked() + } + + // startLocked may fail, leaving r.bouncer nil. + if r.bouncer == nil { + return nil + } + + r.refs[svcID] = struct{}{} + return r.bouncer +} + +// Release removes svcID as a consumer. Stops the bouncer when the last +// consumer releases. +func (r *Registry) Release(svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.refs, svcID) + + if len(r.refs) == 0 && r.bouncer != nil { + r.stopLocked() + } +} + +func (r *Registry) startLocked() { + b := NewBouncer(r.apiURL, r.apiKey, r.logger) + + ctx, cancel := context.WithCancel(context.Background()) + r.cancel = cancel + + if err := b.Start(ctx); err != nil { + r.logger.Errorf("failed to start CrowdSec bouncer: %v", err) + cancel() + return + } + + r.bouncer = b + r.logger.Info("CrowdSec bouncer started") +} + +func (r *Registry) stopLocked() { + r.bouncer.Stop() + r.cancel() + r.bouncer = nil + r.cancel = nil + r.logger.Info("CrowdSec bouncer stopped") +} diff --git a/proxy/internal/crowdsec/registry_test.go b/proxy/internal/crowdsec/registry_test.go new file mode 100644 index 000000000..f1567b186 --- /dev/null +++ b/proxy/internal/crowdsec/registry_test.go @@ -0,0 +1,66 @@ +package crowdsec + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRegistry_Available(t *testing.T) { + r := NewRegistry("http://localhost:8080/", "test-key", log.NewEntry(log.StandardLogger())) + assert.True(t, r.Available()) + + r2 := NewRegistry("", "", log.NewEntry(log.StandardLogger())) + assert.False(t, r2.Available()) + + r3 := NewRegistry("http://localhost:8080/", "", log.NewEntry(log.StandardLogger())) + assert.False(t, r3.Available()) +} + +func TestRegistry_Acquire_NotAvailable(t *testing.T) { + r := NewRegistry("", "", log.NewEntry(log.StandardLogger())) + b := r.Acquire("svc-1") + assert.Nil(t, b) +} + +func TestRegistry_Acquire_Idempotent(t *testing.T) { + r := newTestRegistry() + + b1 := r.Acquire("svc-1") + // Can't start without a real LAPI, but we can verify the ref tracking. + // The bouncer will be nil because Start fails, but the ref is tracked. + _ = b1 + + assert.Len(t, r.refs, 1) + + // Second acquire of same service should not add another ref. + r.Acquire("svc-1") + assert.Len(t, r.refs, 1) +} + +func TestRegistry_Release_Removes(t *testing.T) { + r := newTestRegistry() + r.refs[types.ServiceID("svc-1")] = struct{}{} + + r.Release("svc-1") + assert.Empty(t, r.refs) +} + +func TestRegistry_Release_Noop(t *testing.T) { + r := newTestRegistry() + // Releasing a service that was never acquired should not panic. + r.Release("nonexistent") + assert.Empty(t, r.refs) +} + +func newTestRegistry() *Registry { + return &Registry{ + apiURL: "http://localhost:8080/", + apiKey: "test-key", + logger: log.NewEntry(log.StandardLogger()), + refs: make(map[types.ServiceID]struct{}), + } +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 885c574bc..01b0bc8e6 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -152,7 +152,7 @@ func (c *Client) printClients(data map[string]any) { return } - _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT") + _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "SERVICES", "HAS CLIENT") _, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110)) for _, item := range clients { @@ -166,7 +166,7 @@ func (c *Client) printClientRow(item any) { return } - domains := c.extractDomains(client) + services := c.extractServiceKeys(client) hasClient := "no" if hc, ok := client["has_client"].(bool); ok && hc { hasClient = "yes" @@ -175,20 +175,20 @@ func (c *Client) printClientRow(item any) { _, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n", client["account_id"], client["age"], - domains, + services, hasClient, ) } -func (c *Client) extractDomains(client map[string]any) string { - d, ok := client["domains"].([]any) +func (c *Client) extractServiceKeys(client map[string]any) string { + d, ok := client["service_keys"].([]any) if !ok || len(d) == 0 { return "-" } parts := make([]string, len(d)) - for i, domain := range d { - parts[i] = fmt.Sprint(domain) + for i, key := range d { + parts[i] = fmt.Sprint(key) } return strings.Join(parts, ", ") } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index ab75c8b72..c507cfad9 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -189,7 +189,7 @@ type indexData struct { Version string Uptime string ClientCount int - TotalDomains int + TotalServices int CertsTotal int CertsReady int CertsPending int @@ -202,7 +202,7 @@ type indexData struct { type clientData struct { AccountID string - Domains string + Services string Age string Status string } @@ -211,9 +211,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b clients := h.provider.ListClientsForDebug() sortedIDs := sortedAccountIDs(clients) - totalDomains := 0 + totalServices := 0 for _, info := range clients { - totalDomains += info.DomainCount + totalServices += info.ServiceCount } var certsTotal, certsReady, certsPending, certsFailed int @@ -234,24 +234,24 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } resp := map[string]interface{}{ - "version": version.NetbirdVersion(), - "uptime": time.Since(h.startTime).Round(time.Second).String(), - "client_count": len(clients), - "total_domains": totalDomains, - "certs_total": certsTotal, - "certs_ready": certsReady, - "certs_pending": certsPending, - "certs_failed": certsFailed, - "clients": clientsJSON, + "version": version.NetbirdVersion(), + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "client_count": len(clients), + "total_services": totalServices, + "certs_total": certsTotal, + "certs_ready": certsReady, + "certs_pending": certsPending, + "certs_failed": certsFailed, + "clients": clientsJSON, } if len(certsPendingDomains) > 0 { resp["certs_pending_domains"] = certsPendingDomains @@ -278,7 +278,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b Version: version.NetbirdVersion(), Uptime: time.Since(h.startTime).Round(time.Second).String(), ClientCount: len(clients), - TotalDomains: totalDomains, + TotalServices: totalServices, CertsTotal: certsTotal, CertsReady: certsReady, CertsPending: certsPending, @@ -291,9 +291,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -301,7 +301,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) @@ -324,12 +324,12 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } h.writeJSON(w, map[string]interface{}{ @@ -347,9 +347,9 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -357,7 +357,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) @@ -409,17 +409,13 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc } pbStatus := nbstatus.ToProtoFullStatus(fullStatus) - overview := nbstatus.ConvertToStatusOutputOverview( - pbStatus, - false, - version.NetbirdVersion(), - statusFilter, - prefixNamesFilter, - prefixNamesFilterMap, - ipsFilterMap, - connectionTypeFilter, - "", - ) + overview := nbstatus.ConvertToStatusOutputOverview(pbStatus, nbstatus.ConvertOptions{ + StatusFilter: statusFilter, + PrefixNamesFilter: prefixNamesFilter, + PrefixNamesFilterMap: prefixNamesFilterMap, + IPsFilter: ipsFilterMap, + ConnectionTypeFilter: connectionTypeFilter, + }) if wantJSON { h.writeJSON(w, map[string]interface{}{ diff --git a/proxy/internal/debug/templates/clients.html b/proxy/internal/debug/templates/clients.html index 4d455b2bb..bfc25f95a 100644 --- a/proxy/internal/debug/templates/clients.html +++ b/proxy/internal/debug/templates/clients.html @@ -12,14 +12,14 @@ - + {{range .Clients}} - + diff --git a/proxy/internal/debug/templates/index.html b/proxy/internal/debug/templates/index.html index 16ab3d979..5bd25adfc 100644 --- a/proxy/internal/debug/templates/index.html +++ b/proxy/internal/debug/templates/index.html @@ -27,19 +27,19 @@
    {{range .CertsFailedDomains}}
  • {{.Domain}}: {{.Error}}
  • {{end}}
{{end}} -

Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})

+

Clients ({{.ClientCount}}) | Services ({{.TotalServices}})

{{if .Clients}}
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}
- + {{range .Clients}} - + diff --git a/proxy/internal/geolocation/download.go b/proxy/internal/geolocation/download.go new file mode 100644 index 000000000..64d515275 --- /dev/null +++ b/proxy/internal/geolocation/download.go @@ -0,0 +1,264 @@ +package geolocation + +import ( + "archive/tar" + "bufio" + "compress/gzip" + "crypto/sha256" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz" + mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256" + mmdbInnerName = "GeoLite2-City.mmdb" + + downloadTimeout = 2 * time.Minute + maxMMDBSize = 256 << 20 // 256 MB +) + +// ensureMMDB checks for an existing MMDB file in dataDir. If none is found, +// it downloads from pkgs.netbird.io with SHA256 verification. +func ensureMMDB(logger *log.Logger, dataDir string) (string, error) { + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err) + } + + pattern := filepath.Join(dataDir, mmdbGlob) + if files, _ := filepath.Glob(pattern); len(files) > 0 { + mmdbPath := files[len(files)-1] + logger.Debugf("using existing geolocation database: %s", mmdbPath) + return mmdbPath, nil + } + + logger.Info("geolocation database not found, downloading from pkgs.netbird.io") + return downloadMMDB(logger, dataDir) +} + +func downloadMMDB(logger *log.Logger, dataDir string) (string, error) { + client := &http.Client{Timeout: downloadTimeout} + + datedName, err := fetchRemoteFilename(client, mmdbTarGZURL) + if err != nil { + return "", fmt.Errorf("get remote filename: %w", err) + } + + mmdbFilename := deriveMMDBFilename(datedName) + mmdbPath := filepath.Join(dataDir, mmdbFilename) + + tmp, err := os.MkdirTemp("", "geolite-proxy-*") + if err != nil { + return "", fmt.Errorf("create temp directory: %w", err) + } + defer os.RemoveAll(tmp) + + checksumFile := filepath.Join(tmp, "checksum.sha256") + if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil { + return "", fmt.Errorf("download checksum: %w", err) + } + + expectedHash, err := readChecksumFile(checksumFile) + if err != nil { + return "", fmt.Errorf("read checksum: %w", err) + } + + tarFile := filepath.Join(tmp, datedName) + logger.Debugf("downloading geolocation database (%s)", datedName) + if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil { + return "", fmt.Errorf("download database: %w", err) + } + + if err := verifySHA256(tarFile, expectedHash); err != nil { + return "", fmt.Errorf("verify database checksum: %w", err) + } + + if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil { + return "", fmt.Errorf("extract database: %w", err) + } + + logger.Infof("geolocation database downloaded: %s", mmdbPath) + return mmdbPath, nil +} + +// deriveMMDBFilename converts a tar.gz filename to an MMDB filename. +// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb +func deriveMMDBFilename(tarName string) string { + base, _, _ := strings.Cut(tarName, ".") + if !strings.Contains(base, "_") { + return "GeoLite2-City.mmdb" + } + return base + ".mmdb" +} + +func fetchRemoteFilename(client *http.Client, url string) (string, error) { + resp, err := client.Head(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode) + } + + cd := resp.Header.Get("Content-Disposition") + if cd == "" { + return "", errors.New("no Content-Disposition header") + } + + _, params, err := mime.ParseMediaType(cd) + if err != nil { + return "", fmt.Errorf("parse Content-Disposition: %w", err) + } + + name := filepath.Base(params["filename"]) + if name == "" || name == "." { + return "", errors.New("no filename in Content-Disposition") + } + return name, nil +} + +func downloadToFile(client *http.Client, url, dest string) error { + resp, err := client.Get(url) //nolint:gosec + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + f, err := os.Create(dest) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + + // Cap download at 256 MB to prevent unbounded reads from a compromised server. + if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil { + return err + } + return nil +} + +func readChecksumFile(path string) (string, error) { + f, err := os.Open(path) //nolint:gosec + if err != nil { + return "", err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + if scanner.Scan() { + parts := strings.Fields(scanner.Text()) + if len(parts) > 0 { + return parts[0], nil + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return "", errors.New("empty checksum file") +} + +func verifySHA256(path, expected string) error { + f, err := os.Open(path) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + + actual := fmt.Sprintf("%x", h.Sum(nil)) + if actual != expected { + return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual) + } + return nil +} + +func extractMMDBFromTarGZ(tarGZPath, destPath string) error { + f, err := os.Open(tarGZPath) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + + gz, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + + if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName { + if hdr.Size < 0 || hdr.Size > maxMMDBSize { + return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize) + } + if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil { + return err + } + return nil + } + } + + return fmt.Errorf("%s not found in archive", mmdbInnerName) +} + +// extractToFileAtomic writes r to a temporary file in the same directory as +// destPath, then renames it into place so a crash never leaves a truncated file. +func extractToFileAtomic(r io.Reader, destPath string) error { + dir := filepath.Dir(destPath) + tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp") + if err != nil { + return fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmp.Name() + + if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader + if closeErr := tmp.Close(); closeErr != nil { + log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr) + } + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr) + } + return fmt.Errorf("write mmdb: %w", err) + } + if err := tmp.Close(); err != nil { + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr) + } + return fmt.Errorf("close temp file: %w", err) + } + if err := os.Rename(tmpPath, destPath); err != nil { + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr) + } + return fmt.Errorf("rename to %s: %w", destPath, err) + } + return nil +} diff --git a/proxy/internal/geolocation/geolocation.go b/proxy/internal/geolocation/geolocation.go new file mode 100644 index 000000000..81b02efb3 --- /dev/null +++ b/proxy/internal/geolocation/geolocation.go @@ -0,0 +1,152 @@ +// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases. +package geolocation + +import ( + "fmt" + "net/netip" + "os" + "strconv" + "sync" + + "github.com/oschwald/maxminddb-golang" + log "github.com/sirupsen/logrus" +) + +const ( + // EnvDisable disables geolocation lookups entirely when set to a truthy value. + EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION" + + mmdbGlob = "GeoLite2-City_*.mmdb" +) + +type record struct { + Country struct { + ISOCode string `maxminddb:"iso_code"` + } `maxminddb:"country"` + City struct { + Names struct { + En string `maxminddb:"en"` + } `maxminddb:"names"` + } `maxminddb:"city"` + Subdivisions []struct { + ISOCode string `maxminddb:"iso_code"` + Names struct { + En string `maxminddb:"en"` + } `maxminddb:"names"` + } `maxminddb:"subdivisions"` +} + +// Result holds the outcome of a geo lookup. +type Result struct { + CountryCode string + CityName string + SubdivisionCode string + SubdivisionName string +} + +// Lookup provides IP geolocation lookups. +type Lookup struct { + mu sync.RWMutex + db *maxminddb.Reader + logger *log.Logger +} + +// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir. +// Returns nil without error if geolocation is disabled via environment +// variable, no data directory is configured, or the download fails +// (graceful degradation: country restrictions will deny all requests). +func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) { + if isDisabledByEnv(logger) { + logger.Info("geolocation disabled via environment variable") + return nil, nil //nolint:nilnil + } + + if dataDir == "" { + return nil, nil //nolint:nilnil + } + + mmdbPath, err := ensureMMDB(logger, dataDir) + if err != nil { + logger.Warnf("geolocation database unavailable: %v", err) + logger.Warn("country-based access restrictions will deny all requests until a database is available") + return nil, nil //nolint:nilnil + } + + db, err := maxminddb.Open(mmdbPath) + if err != nil { + return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err) + } + + logger.Infof("geolocation database loaded from %s", mmdbPath) + return &Lookup{db: db, logger: logger}, nil +} + +// LookupAddr returns the country ISO code and city name for the given IP. +// Returns an empty Result if the database is nil or the lookup fails. +func (l *Lookup) LookupAddr(addr netip.Addr) Result { + if l == nil { + return Result{} + } + + l.mu.RLock() + defer l.mu.RUnlock() + + if l.db == nil { + return Result{} + } + + addr = addr.Unmap() + + var rec record + if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil { + l.logger.Debugf("geolocation lookup %s: %v", addr, err) + return Result{} + } + r := Result{ + CountryCode: rec.Country.ISOCode, + CityName: rec.City.Names.En, + } + if len(rec.Subdivisions) > 0 { + r.SubdivisionCode = rec.Subdivisions[0].ISOCode + r.SubdivisionName = rec.Subdivisions[0].Names.En + } + return r +} + +// Available reports whether the lookup has a loaded database. +func (l *Lookup) Available() bool { + if l == nil { + return false + } + l.mu.RLock() + defer l.mu.RUnlock() + return l.db != nil +} + +// Close releases the database resources. +func (l *Lookup) Close() error { + if l == nil { + return nil + } + l.mu.Lock() + defer l.mu.Unlock() + if l.db != nil { + err := l.db.Close() + l.db = nil + return err + } + return nil +} + +func isDisabledByEnv(logger *log.Logger) bool { + val := os.Getenv(EnvDisable) + if val == "" { + return false + } + disabled, err := strconv.ParseBool(val) + if err != nil { + logger.Warnf("parse %s=%q: %v", EnvDisable, val, err) + return false + } + return disabled +} diff --git a/proxy/internal/metrics/l4_metrics_test.go b/proxy/internal/metrics/l4_metrics_test.go new file mode 100644 index 000000000..055158828 --- /dev/null +++ b/proxy/internal/metrics/l4_metrics_test.go @@ -0,0 +1,69 @@ +package metrics_test + +import ( + "context" + "reflect" + "testing" + "time" + + promexporter "go.opentelemetry.io/otel/exporters/prometheus" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + + "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func newTestMetrics(t *testing.T) *metrics.Metrics { + t.Helper() + + exporter, err := promexporter.New() + if err != nil { + t.Fatalf("create prometheus exporter: %v", err) + } + + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter)) + pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath() + meter := provider.Meter(pkg) + + m, err := metrics.New(context.Background(), meter) + if err != nil { + t.Fatalf("create metrics: %v", err) + } + return m +} + +func TestL4ServiceGauge(t *testing.T) { + m := newTestMetrics(t) + + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeUDP) + m.L4ServiceRemoved(types.ServiceModeTCP) +} + +func TestTCPRelayMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-1") + + m.TCPRelayStarted(acct) + m.TCPRelayStarted(acct) + m.TCPRelayEnded(acct, 10*time.Second, 1000, 500) + m.TCPRelayDialError(acct) + m.TCPRelayRejected(acct) +} + +func TestUDPSessionMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-2") + + m.UDPSessionStarted(acct) + m.UDPSessionStarted(acct) + m.UDPSessionEnded(acct) + m.UDPSessionDialError(acct) + m.UDPSessionRejected(acct) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 100) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 200) + m.UDPPacketRelayed(types.RelayDirectionBackendToClient, 150) +} diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 954020f77..573485625 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -1,64 +1,212 @@ package metrics import ( + "context" "net/http" - "strconv" + "sync" "time" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/responsewriter" + "github.com/netbirdio/netbird/proxy/internal/types" ) +// Metrics collects OpenTelemetry metrics for the proxy. type Metrics struct { - requestsTotal prometheus.Counter - activeRequests prometheus.Gauge - configuredDomains prometheus.Gauge - pathsPerDomain *prometheus.GaugeVec - requestDuration *prometheus.HistogramVec - backendDuration *prometheus.HistogramVec + ctx context.Context + requestsTotal metric.Int64Counter + activeRequests metric.Int64UpDownCounter + configuredDomains metric.Int64UpDownCounter + totalPaths metric.Int64UpDownCounter + requestDuration metric.Int64Histogram + backendDuration metric.Int64Histogram + certificateIssueDuration metric.Int64Histogram + + // L4 service-level metrics. + l4Services metric.Int64UpDownCounter + + // L4 TCP connection-level metrics. + tcpActiveConns metric.Int64UpDownCounter + tcpConnsTotal metric.Int64Counter + tcpConnDuration metric.Int64Histogram + tcpBytesTotal metric.Int64Counter + + // L4 UDP session-level metrics. + udpActiveSess metric.Int64UpDownCounter + udpSessionsTotal metric.Int64Counter + udpPacketsTotal metric.Int64Counter + udpBytesTotal metric.Int64Counter + + mappingsMux sync.Mutex + mappingPaths map[string]int } -func New(reg prometheus.Registerer) *Metrics { - promFactory := promauto.With(reg) - return &Metrics{ - requestsTotal: promFactory.NewCounter(prometheus.CounterOpts{ - Name: "netbird_proxy_requests_total", - Help: "Total number of requests made to the netbird proxy", - }), - activeRequests: promFactory.NewGauge(prometheus.GaugeOpts{ - Name: "netbird_proxy_active_requests_count", - Help: "Current in-flight requests handled by the netbird proxy", - }), - configuredDomains: promFactory.NewGauge(prometheus.GaugeOpts{ - Name: "netbird_proxy_domains_count", - Help: "Current number of domains configured on the netbird proxy", - }), - pathsPerDomain: promFactory.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "netbird_proxy_paths_count", - Help: "Current number of paths configured on the netbird proxy labelled by domain", - }, - []string{"domain"}, - ), - requestDuration: promFactory.NewHistogramVec( - prometheus.HistogramOpts{ - Name: "netbird_proxy_request_duration_seconds", - Help: "Duration of requests made to the netbird proxy", - Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, - }, - []string{"status", "size", "method", "host", "path"}, - ), - backendDuration: promFactory.NewHistogramVec(prometheus.HistogramOpts{ - Name: "netbird_proxy_backend_duration_seconds", - Help: "Duration of peer round trip time from the netbird proxy", - Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, - }, - []string{"status", "size", "method", "host", "path"}, - ), +// New creates a Metrics instance using the given OpenTelemetry meter. +func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { + m := &Metrics{ + ctx: ctx, + mappingPaths: make(map[string]int), } + + if err := m.initHTTPMetrics(meter); err != nil { + return nil, err + } + if err := m.initL4Metrics(meter); err != nil { + return nil, err + } + + return m, nil +} + +func (m *Metrics) initHTTPMetrics(meter metric.Meter) error { + var err error + + m.requestsTotal, err = meter.Int64Counter( + "proxy.http.request.counter", + metric.WithUnit("1"), + metric.WithDescription("Total number of requests made to the netbird proxy"), + ) + if err != nil { + return err + } + + m.activeRequests, err = meter.Int64UpDownCounter( + "proxy.http.active_requests", + metric.WithUnit("1"), + metric.WithDescription("Current in-flight requests handled by the netbird proxy"), + ) + if err != nil { + return err + } + + m.configuredDomains, err = meter.Int64UpDownCounter( + "proxy.domains.count", + metric.WithUnit("1"), + metric.WithDescription("Current number of domains configured on the netbird proxy"), + ) + if err != nil { + return err + } + + m.totalPaths, err = meter.Int64UpDownCounter( + "proxy.paths.count", + metric.WithUnit("1"), + metric.WithDescription("Total number of paths configured on the netbird proxy"), + ) + if err != nil { + return err + } + + m.requestDuration, err = meter.Int64Histogram( + "proxy.http.request.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of requests made to the netbird proxy"), + ) + if err != nil { + return err + } + + m.backendDuration, err = meter.Int64Histogram( + "proxy.backend.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of peer round trip time from the netbird proxy"), + ) + if err != nil { + return err + } + + m.certificateIssueDuration, err = meter.Int64Histogram( + "proxy.certificate.issue.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of ACME certificate issuance"), + ) + return err +} + +func (m *Metrics) initL4Metrics(meter metric.Meter) error { + var err error + + m.l4Services, err = meter.Int64UpDownCounter( + "proxy.l4.services.count", + metric.WithUnit("1"), + metric.WithDescription("Current number of configured L4 services (TCP/TLS/UDP) by mode"), + ) + if err != nil { + return err + } + + m.tcpActiveConns, err = meter.Int64UpDownCounter( + "proxy.tcp.active_connections", + metric.WithUnit("1"), + metric.WithDescription("Current number of active TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpConnsTotal, err = meter.Int64Counter( + "proxy.tcp.connections.total", + metric.WithUnit("1"), + metric.WithDescription("Total TCP/TLS relay connections by result and account"), + ) + if err != nil { + return err + } + + m.tcpConnDuration, err = meter.Int64Histogram( + "proxy.tcp.connection.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpBytesTotal, err = meter.Int64Counter( + "proxy.tcp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through TCP/TLS relay by direction"), + ) + if err != nil { + return err + } + + m.udpActiveSess, err = meter.Int64UpDownCounter( + "proxy.udp.active_sessions", + metric.WithUnit("1"), + metric.WithDescription("Current number of active UDP relay sessions"), + ) + if err != nil { + return err + } + + m.udpSessionsTotal, err = meter.Int64Counter( + "proxy.udp.sessions.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP relay sessions by result and account"), + ) + if err != nil { + return err + } + + m.udpPacketsTotal, err = meter.Int64Counter( + "proxy.udp.packets.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP packets relayed by direction"), + ) + if err != nil { + return err + } + + m.udpBytesTotal, err = meter.Int64Counter( + "proxy.udp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through UDP relay by direction"), + ) + return err } type responseInterceptor struct { @@ -78,25 +226,28 @@ func (w *responseInterceptor) Write(b []byte) (int, error) { return size, err } +// Unwrap returns the underlying ResponseWriter so http.ResponseController +// can reach through to the original writer for Hijack/Flush operations. +func (w *responseInterceptor) Unwrap() http.ResponseWriter { + return w.PassthroughWriter +} + +// Middleware wraps an HTTP handler with request metrics. func (m *Metrics) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - m.requestsTotal.Inc() - m.activeRequests.Inc() + m.requestsTotal.Add(m.ctx, 1) + m.activeRequests.Add(m.ctx, 1) interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)} start := time.Now() - next.ServeHTTP(interceptor, r) - duration := time.Since(start) + defer func() { + duration := time.Since(start) + m.activeRequests.Add(m.ctx, -1) + m.requestDuration.Record(m.ctx, duration.Milliseconds()) + }() - m.activeRequests.Desc() - m.requestDuration.With(prometheus.Labels{ - "status": strconv.Itoa(interceptor.status), - "size": strconv.Itoa(interceptor.size), - "method": r.Method, - "host": r.Host, - "path": r.URL.Path, - }).Observe(duration.Seconds()) + next.ServeHTTP(interceptor, r) }) } @@ -106,46 +257,133 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } +// RoundTripper wraps an http.RoundTripper with backend duration metrics. func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { - labels := prometheus.Labels{ - "method": req.Method, - "host": req.Host, - // Fill potentially empty labels with default values to avoid cardinality issues. - "path": "/", - "status": "0", - "size": "0", - } - if req.URL != nil { - labels["path"] = req.URL.Path - } - start := time.Now() res, err := next.RoundTrip(req) duration := time.Since(start) - // Not all labels will be available if there was an error. - if res != nil { - labels["status"] = strconv.Itoa(res.StatusCode) - labels["size"] = strconv.Itoa(int(res.ContentLength)) - } - - m.backendDuration.With(labels).Observe(duration.Seconds()) + m.backendDuration.Record(m.ctx, duration.Milliseconds()) return res, err }) } +// AddMapping records that a domain mapping was added. func (m *Metrics) AddMapping(mapping proxy.Mapping) { - m.configuredDomains.Inc() - m.pathsPerDomain.With(prometheus.Labels{ - "domain": mapping.Host, - }).Set(float64(len(mapping.Paths))) + m.mappingsMux.Lock() + defer m.mappingsMux.Unlock() + + newPathCount := len(mapping.Paths) + oldPathCount, exists := m.mappingPaths[mapping.Host] + + if !exists { + m.configuredDomains.Add(m.ctx, 1) + } + + pathDelta := newPathCount - oldPathCount + if pathDelta != 0 { + m.totalPaths.Add(m.ctx, int64(pathDelta)) + } + + m.mappingPaths[mapping.Host] = newPathCount } +// RemoveMapping records that a domain mapping was removed. func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { - m.configuredDomains.Dec() - m.pathsPerDomain.With(prometheus.Labels{ - "domain": mapping.Host, - }).Set(0) + m.mappingsMux.Lock() + defer m.mappingsMux.Unlock() + + oldPathCount, exists := m.mappingPaths[mapping.Host] + if !exists { + return + } + + m.configuredDomains.Add(m.ctx, -1) + m.totalPaths.Add(m.ctx, -int64(oldPathCount)) + + delete(m.mappingPaths, mapping.Host) +} + +// RecordCertificateIssuance records the duration of a certificate issuance. +func (m *Metrics) RecordCertificateIssuance(duration time.Duration) { + m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds()) +} + +// L4ServiceAdded increments the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceAdded(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, 1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// L4ServiceRemoved decrements the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceRemoved(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, -1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// TCPRelayStarted records a new TCP relay connection starting. +func (m *Metrics) TCPRelayStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// TCPRelayEnded records a TCP relay connection ending and accumulates bytes and duration. +func (m *Metrics) TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, -1, metric.WithAttributes(acct)) + m.tcpConnDuration.Record(m.ctx, duration.Milliseconds(), metric.WithAttributes(acct)) + m.tcpBytesTotal.Add(m.ctx, srcToDst, metric.WithAttributes(attribute.String("direction", "client_to_backend"))) + m.tcpBytesTotal.Add(m.ctx, dstToSrc, metric.WithAttributes(attribute.String("direction", "backend_to_client"))) +} + +// TCPRelayDialError records a dial failure for a TCP relay. +func (m *Metrics) TCPRelayDialError(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// TCPRelayRejected records a rejected TCP relay (semaphore full). +func (m *Metrics) TCPRelayRejected(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPSessionStarted records a new UDP session starting. +func (m *Metrics) UDPSessionStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.udpActiveSess.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// UDPSessionEnded records a UDP session ending. +func (m *Metrics) UDPSessionEnded(accountID types.AccountID) { + m.udpActiveSess.Add(m.ctx, -1, metric.WithAttributes(attribute.String("account_id", string(accountID)))) +} + +// UDPSessionDialError records a dial failure for a UDP session. +func (m *Metrics) UDPSessionDialError(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// UDPSessionRejected records a rejected UDP session (limit or rate limited). +func (m *Metrics) UDPSessionRejected(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPPacketRelayed records a packet relayed in the given direction with its size in bytes. +func (m *Metrics) UDPPacketRelayed(direction types.RelayDirection, bytes int) { + dir := attribute.String("direction", string(direction)) + m.udpPacketsTotal.Add(m.ctx, 1, metric.WithAttributes(dir)) + m.udpBytesTotal.Add(m.ctx, int64(bytes), metric.WithAttributes(dir)) } diff --git a/proxy/internal/metrics/metrics_test.go b/proxy/internal/metrics/metrics_test.go index 31e00ae64..f81072eda 100644 --- a/proxy/internal/metrics/metrics_test.go +++ b/proxy/internal/metrics/metrics_test.go @@ -1,13 +1,17 @@ package metrics_test import ( + "context" "net/http" "net/url" + "reflect" "testing" "github.com/google/go-cmp/cmp" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/sdk/metric" + "github.com/netbirdio/netbird/proxy/internal/metrics" - "github.com/prometheus/client_golang/prometheus" ) type testRoundTripper struct { @@ -47,7 +51,19 @@ func TestMetrics_RoundTripper(t *testing.T) { }, } - m := metrics.New(prometheus.NewRegistry()) + exporter, err := prometheus.New() + if err != nil { + t.Fatalf("create prometheus exporter: %v", err) + } + + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath() + meter := provider.Meter(pkg) + + m, err := metrics.New(context.Background(), meter) + if err != nil { + t.Fatalf("create metrics: %v", err) + } for name, test := range tests { t.Run(name, func(t *testing.T) { diff --git a/proxy/internal/netutil/errors.go b/proxy/internal/netutil/errors.go new file mode 100644 index 000000000..ff24e33d4 --- /dev/null +++ b/proxy/internal/netutil/errors.go @@ -0,0 +1,40 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net" + "syscall" +) + +// ValidatePort converts an int32 proto port to uint16, returning an error +// if the value is out of the valid 1–65535 range. +func ValidatePort(port int32) (uint16, error) { + if port <= 0 || port > math.MaxUint16 { + return 0, fmt.Errorf("invalid port %d: must be 1–65535", port) + } + return uint16(port), nil +} + +// IsExpectedError returns true for errors that are normal during +// connection teardown and should not be logged as warnings. +func IsExpectedError(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNABORTED) +} + +// IsTimeout checks whether the error is a network timeout. +func IsTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} diff --git a/proxy/internal/netutil/errors_test.go b/proxy/internal/netutil/errors_test.go new file mode 100644 index 000000000..7d6be10ff --- /dev/null +++ b/proxy/internal/netutil/errors_test.go @@ -0,0 +1,92 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + port int32 + want uint16 + wantErr bool + }{ + {"valid min", 1, 1, false}, + {"valid mid", 8080, 8080, false}, + {"valid max", 65535, 65535, false}, + {"zero", 0, 0, true}, + {"negative", -1, 0, true}, + {"too large", 65536, 0, true}, + {"way too large", 100000, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ValidatePort(tt.port) + if tt.wantErr { + assert.Error(t, err) + assert.Zero(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestIsExpectedError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net.ErrClosed", net.ErrClosed, true}, + {"context.Canceled", context.Canceled, true}, + {"io.EOF", io.EOF, true}, + {"ECONNRESET", syscall.ECONNRESET, true}, + {"EPIPE", syscall.EPIPE, true}, + {"ECONNABORTED", syscall.ECONNABORTED, true}, + {"wrapped expected", fmt.Errorf("wrap: %w", net.ErrClosed), true}, + {"unexpected EOF", io.ErrUnexpectedEOF, false}, + {"generic error", errors.New("something"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsExpectedError(tt.err)) + }) + } +} + +type timeoutErr struct{ timeout bool } + +func (e *timeoutErr) Error() string { return "timeout" } +func (e *timeoutErr) Timeout() bool { return e.timeout } +func (e *timeoutErr) Temporary() bool { return false } + +func TestIsTimeout(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net timeout", &timeoutErr{timeout: true}, true}, + {"net non-timeout", &timeoutErr{timeout: false}, false}, + {"wrapped timeout", fmt.Errorf("wrap: %w", &timeoutErr{timeout: true}), true}, + {"generic error", errors.New("not a timeout"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsTimeout(tt.err)) + }) + } +} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 22ebbf371..a888ad9ed 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -2,6 +2,8 @@ package proxy import ( "context" + "maps" + "net/netip" "sync" "github.com/netbirdio/netbird/proxy/internal/types" @@ -10,8 +12,6 @@ import ( type requestContextKey string const ( - serviceIdKey requestContextKey = "serviceId" - accountIdKey requestContextKey = "accountId" capturedDataKey requestContextKey = "capturedData" ) @@ -46,112 +46,135 @@ func (o ResponseOrigin) String() string { // to pass data back up the middleware chain. type CapturedData struct { mu sync.RWMutex - RequestID string - ServiceId string - AccountId types.AccountID - Origin ResponseOrigin - ClientIP string - UserID string - AuthMethod string + requestID string + serviceID types.ServiceID + accountID types.AccountID + origin ResponseOrigin + clientIP netip.Addr + userID string + authMethod string + metadata map[string]string } -// GetRequestID safely gets the request ID +// NewCapturedData creates a CapturedData with the given request ID. +func NewCapturedData(requestID string) *CapturedData { + return &CapturedData{requestID: requestID} +} + +// GetRequestID returns the request ID. func (c *CapturedData) GetRequestID() string { c.mu.RLock() defer c.mu.RUnlock() - return c.RequestID + return c.requestID } -// SetServiceId safely sets the service ID -func (c *CapturedData) SetServiceId(serviceId string) { +// SetServiceID sets the service ID. +func (c *CapturedData) SetServiceID(serviceID types.ServiceID) { c.mu.Lock() defer c.mu.Unlock() - c.ServiceId = serviceId + c.serviceID = serviceID } -// GetServiceId safely gets the service ID -func (c *CapturedData) GetServiceId() string { +// GetServiceID returns the service ID. +func (c *CapturedData) GetServiceID() types.ServiceID { c.mu.RLock() defer c.mu.RUnlock() - return c.ServiceId + return c.serviceID } -// SetAccountId safely sets the account ID -func (c *CapturedData) SetAccountId(accountId types.AccountID) { +// SetAccountID sets the account ID. +func (c *CapturedData) SetAccountID(accountID types.AccountID) { c.mu.Lock() defer c.mu.Unlock() - c.AccountId = accountId + c.accountID = accountID } -// GetAccountId safely gets the account ID -func (c *CapturedData) GetAccountId() types.AccountID { +// GetAccountID returns the account ID. +func (c *CapturedData) GetAccountID() types.AccountID { c.mu.RLock() defer c.mu.RUnlock() - return c.AccountId + return c.accountID } -// SetOrigin safely sets the response origin +// SetOrigin sets the response origin. func (c *CapturedData) SetOrigin(origin ResponseOrigin) { c.mu.Lock() defer c.mu.Unlock() - c.Origin = origin + c.origin = origin } -// GetOrigin safely gets the response origin +// GetOrigin returns the response origin. func (c *CapturedData) GetOrigin() ResponseOrigin { c.mu.RLock() defer c.mu.RUnlock() - return c.Origin + return c.origin } -// SetClientIP safely sets the resolved client IP. -func (c *CapturedData) SetClientIP(ip string) { +// SetClientIP sets the resolved client IP. +func (c *CapturedData) SetClientIP(ip netip.Addr) { c.mu.Lock() defer c.mu.Unlock() - c.ClientIP = ip + c.clientIP = ip } -// GetClientIP safely gets the resolved client IP. -func (c *CapturedData) GetClientIP() string { +// GetClientIP returns the resolved client IP. +func (c *CapturedData) GetClientIP() netip.Addr { c.mu.RLock() defer c.mu.RUnlock() - return c.ClientIP + return c.clientIP } -// SetUserID safely sets the authenticated user ID. +// SetUserID sets the authenticated user ID. func (c *CapturedData) SetUserID(userID string) { c.mu.Lock() defer c.mu.Unlock() - c.UserID = userID + c.userID = userID } -// GetUserID safely gets the authenticated user ID. +// GetUserID returns the authenticated user ID. func (c *CapturedData) GetUserID() string { c.mu.RLock() defer c.mu.RUnlock() - return c.UserID + return c.userID } -// SetAuthMethod safely sets the authentication method used. +// SetAuthMethod sets the authentication method used. func (c *CapturedData) SetAuthMethod(method string) { c.mu.Lock() defer c.mu.Unlock() - c.AuthMethod = method + c.authMethod = method } -// GetAuthMethod safely gets the authentication method used. +// GetAuthMethod returns the authentication method used. func (c *CapturedData) GetAuthMethod() string { c.mu.RLock() defer c.mu.RUnlock() - return c.AuthMethod + return c.authMethod } -// WithCapturedData adds a CapturedData struct to the context +// SetMetadata sets a key-value pair in the metadata map. +func (c *CapturedData) SetMetadata(key, value string) { + c.mu.Lock() + defer c.mu.Unlock() + if c.metadata == nil { + c.metadata = make(map[string]string) + } + c.metadata[key] = value +} + +// GetMetadata returns a copy of the metadata map. +func (c *CapturedData) GetMetadata() map[string]string { + c.mu.RLock() + defer c.mu.RUnlock() + return maps.Clone(c.metadata) +} + +// WithCapturedData adds a CapturedData struct to the context. func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { return context.WithValue(ctx, capturedDataKey, data) } -// CapturedDataFromContext retrieves the CapturedData from context +// CapturedDataFromContext retrieves the CapturedData from context. func CapturedDataFromContext(ctx context.Context) *CapturedData { v := ctx.Value(capturedDataKey) data, ok := v.(*CapturedData) @@ -160,28 +183,3 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData { } return data } - -func withServiceId(ctx context.Context, serviceId string) context.Context { - return context.WithValue(ctx, serviceIdKey, serviceId) -} - -func ServiceIdFromContext(ctx context.Context) string { - v := ctx.Value(serviceIdKey) - serviceId, ok := v.(string) - if !ok { - return "" - } - return serviceId -} -func withAccountId(ctx context.Context, accountId types.AccountID) context.Context { - return context.WithValue(ctx, accountIdKey, accountId) -} - -func AccountIdFromContext(ctx context.Context) types.AccountID { - v := ctx.Value(accountIdKey) - accountId, ok := v.(types.AccountID) - if !ok { - return "" - } - return accountId -} diff --git a/proxy/internal/proxy/proxy_bench_test.go b/proxy/internal/proxy/proxy_bench_test.go index b7526e26b..b59ef75c0 100644 --- a/proxy/internal/proxy/proxy_bench_test.go +++ b/proxy/internal/proxy/proxy_bench_test.go @@ -25,13 +25,15 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil) rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", - Paths: map[string]*url.URL{ + Paths: map[string]*proxy.PathTarget{ "/": { - Scheme: "http", - Host: "10.0.0.1:8080", + URL: &url.URL{ + Scheme: "http", + Host: "10.0.0.1:8080", + }, }, }, }) @@ -64,13 +66,15 @@ func BenchmarkServeHTTPHostCount(b *testing.B) { target = id } rp.AddMapping(proxy.Mapping{ - ID: id, + ID: types.ServiceID(id), AccountID: types.AccountID(rand.Text()), Host: host, - Paths: map[string]*url.URL{ + Paths: map[string]*proxy.PathTarget{ "/": { - Scheme: "http", - Host: "10.0.0.1:8080", + URL: &url.URL{ + Scheme: "http", + Host: "10.0.0.1:8080", + }, }, }, }) @@ -100,19 +104,21 @@ func BenchmarkServeHTTPPathCount(b *testing.B) { b.Fatal(err) } - paths := make(map[string]*url.URL, pathCount) + paths := make(map[string]*proxy.PathTarget, pathCount) for i := range pathCount { path := "/" + rand.Text() if int64(i) == targetIndex.Int64() { target = path } - paths[path] = &url.URL{ - Scheme: "http", - Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i), + paths[path] = &proxy.PathTarget{ + URL: &url.URL{ + Scheme: "http", + Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i), + }, } } rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: paths, diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index 16607689a..246851d24 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" ) @@ -65,28 +66,40 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Set the serviceId in the context for later retrieval. - ctx := withServiceId(r.Context(), result.serviceID) - // Set the accountId in the context for later retrieval (for middleware). - ctx = withAccountId(ctx, result.accountID) - // Set the accountId in the context for the roundtripper to use. + ctx := r.Context() + // Set the account ID in the context for the roundtripper to use. ctx = roundtrip.WithAccountID(ctx, result.accountID) - // Also populate captured data if it exists (allows middleware to read after handler completes). + // Populate captured data if it exists (allows middleware to read after handler completes). // This solves the problem of passing data UP the middleware chain: we put a mutable struct // pointer in the context, and mutate the struct here so outer middleware can read it. if capturedData := CapturedDataFromContext(ctx); capturedData != nil { - capturedData.SetServiceId(result.serviceID) - capturedData.SetAccountId(result.accountID) + capturedData.SetServiceID(result.serviceID) + capturedData.SetAccountID(result.accountID) + } + + pt := result.target + + if pt.SkipTLSVerify { + ctx = roundtrip.WithSkipTLSVerify(ctx) + } + if pt.RequestTimeout > 0 { + ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) + } + + rewriteMatchedPath := result.matchedPath + if pt.PathRewrite == PathRewritePreserve { + rewriteMatchedPath = "" } rp := &httputil.ReverseProxy{ - Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader), - Transport: p.transport, - ErrorHandler: proxyErrorHandler, + Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders), + Transport: p.transport, + FlushInterval: -1, + ErrorHandler: p.proxyErrorHandler, } if result.rewriteRedirects { - rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose + rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose } rp.ServeHTTP(w, r.WithContext(ctx)) } @@ -96,16 +109,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // forwarding headers and stripping proxy authentication credentials. // When passHostHeader is true, the original client Host header is preserved // instead of being rewritten to the backend's address. -func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) { +// The pathRewrite parameter controls how the request path is transformed. +func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string, stripAuthHeaders []string) func(r *httputil.ProxyRequest) { return func(r *httputil.ProxyRequest) { - // Strip the matched path prefix from the incoming request path before - // SetURL joins it with the target's base path, avoiding path duplication. - if matchedPath != "" && matchedPath != "/" { - r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath) - if r.Out.URL.Path == "" { - r.Out.URL.Path = "/" + switch pathRewrite { + case PathRewritePreserve: + // Keep the full original request path as-is. + default: + if matchedPath != "" && matchedPath != "/" { + // Strip the matched path prefix from the incoming request path before + // SetURL joins it with the target's base path, avoiding path duplication. + r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath) + if r.Out.URL.Path == "" { + r.Out.URL.Path = "/" + } + r.Out.URL.RawPath = "" } - r.Out.URL.RawPath = "" } r.SetURL(target) @@ -115,9 +134,17 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost r.Out.Host = target.Host } - clientIP := extractClientIP(r.In.RemoteAddr) + for _, h := range stripAuthHeaders { + r.Out.Header.Del(h) + } - if IsTrustedProxy(clientIP, p.trustedProxies) { + for k, v := range customHeaders { + r.Out.Header.Set(k, v) + } + + clientIP := extractHostIP(r.In.RemoteAddr) + + if isTrustedAddr(clientIP, p.trustedProxies) { p.setTrustedForwardingHeaders(r, clientIP) } else { p.setUntrustedForwardingHeaders(r, clientIP) @@ -187,12 +214,14 @@ func normalizeHost(u *url.URL) string { // setTrustedForwardingHeaders appends to the existing forwarding header chain // and preserves upstream-provided headers when the direct connection is from // a trusted proxy. -func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() + // Append the direct connection IP to the existing X-Forwarded-For chain. if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" { - r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP) + r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr) } else { - r.Out.Header.Set("X-Forwarded-For", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) } // Preserve upstream X-Real-IP if present; otherwise resolve through the chain. @@ -200,7 +229,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli r.Out.Header.Set("X-Real-IP", realIP) } else { resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies) - r.Out.Header.Set("X-Real-IP", resolved) + r.Out.Header.Set("X-Real-IP", resolved.String()) } // Preserve upstream X-Forwarded-Host if present. @@ -230,10 +259,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli // sets them fresh based on the direct connection. This is the default // behavior when no trusted proxies are configured or the direct connection // is from an untrusted source. -func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() proto := auth.ResolveProto(p.forwardedProto, r.In.TLS) - r.Out.Header.Set("X-Forwarded-For", clientIP) - r.Out.Header.Set("X-Real-IP", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) + r.Out.Header.Set("X-Real-IP", ipStr) r.Out.Header.Set("X-Forwarded-Host", r.In.Host) r.Out.Header.Set("X-Forwarded-Proto", proto) r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto)) @@ -261,16 +291,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) { } } -// extractClientIP extracts the IP address from an http.Request.RemoteAddr -// which is always in host:port format. -func extractClientIP(remoteAddr string) string { - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return remoteAddr - } - return ip -} - // extractForwardedPort returns the port from the Host header if present, // otherwise defaults to the standard port for the resolved protocol. func extractForwardedPort(host, resolvedProto string) string { @@ -286,7 +306,7 @@ func extractForwardedPort(host, resolvedProto string) string { // proxyErrorHandler handles errors from the reverse proxy and serves // user-friendly error pages instead of raw error responses. -func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { +func (p *ReverseProxy) proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { if cd := CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(OriginProxyError) } @@ -294,16 +314,18 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { clientIP := getClientIP(r) title, message, code, status := classifyProxyError(err) - log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v", + p.logger.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v", requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err) web.ServeErrorPage(w, r, code, title, message, requestID, status) } -// getClientIP retrieves the resolved client IP from context. +// getClientIP retrieves the resolved client IP string from context. func getClientIP(r *http.Request) string { if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { - return capturedData.GetClientIP() + if ip := capturedData.GetClientIP(); ip.IsValid() { + return ip.String() + } } return "" } diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index f7f231db4..c53307837 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} t.Run("rewrites host to backend by default", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") rewrite(pr) @@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) { }) t.Run("preserves original host when passHostHeader is true", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "", true) + rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") rewrite(pr) @@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) { func TestRewriteFunc_XForwardedForStripping(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) { pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") @@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000") rewrite(pr) @@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000") rewrite(pr) @@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr.In.TLS = &tls.ConnectionState{} @@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") rewrite(pr) @@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("auto detects https from TLS", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr.In.TLS = &tls.ConnectionState{} @@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("auto detects http without TLS", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") rewrite(pr) @@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("forced proto overrides TLS detection", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "https"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") // No TLS, but forced to https @@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("forced http proto", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "http"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr.In.TLS = &tls.ConnectionState{} @@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { func TestRewriteFunc_SessionCookieStripping(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) t.Run("strips nb_session cookie", func(t *testing.T) { pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") @@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) { func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) t.Run("strips session_token query parameter", func(t *testing.T) { pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000") @@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { t.Run("rewrites URL to target with path prefix", func(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080/app") - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000") rewrite(pr) @@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) { target, _ := url.Parse("https://backend.example.org:443/app") - rewrite := p.rewriteFunc(target, "/app", false) + rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000") rewrite(pr) @@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { t.Run("strips matched prefix and preserves subpath", func(t *testing.T) { target, _ := url.Parse("https://backend.example.org:443/app") - rewrite := p.rewriteFunc(target, "/app", false) + rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000") rewrite(pr) @@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { }) } -func TestExtractClientIP(t *testing.T) { +func TestExtractHostIP(t *testing.T) { tests := []struct { name string remoteAddr string - expected string + expected netip.Addr }{ - {"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"}, - {"IPv6 with port", "[::1]:12345", "::1"}, - {"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"}, - {"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"}, - {"IPv6 without brackets fallback", "::1", "::1"}, - {"empty string fallback", "", ""}, - {"public IP", "203.0.113.50:9999", "203.0.113.50"}, + {"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")}, + {"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")}, + {"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")}, + {"empty string fallback", "", netip.Addr{}}, + {"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr)) + assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr)) }) } } @@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("appends to X-Forwarded-For", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") @@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Real-IP", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") @@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2") @@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-Host", "original.example.com") @@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-Proto", "https") @@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-Port", "8443") @@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") @@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") @@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1") @@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("empty trusted list behaves as untrusted", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") @@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") @@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) { // Management builds: path="/heise", target="https://heise.de:443/heise" target, _ := url.Parse("https://heise.de:443/heise") - rewrite := p.rewriteFunc(target, "/heise", false) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") rewrite(pr) @@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { t.Run("subpath under prefix also preserved", func(t *testing.T) { target, _ := url.Parse("https://heise.de:443/heise") - rewrite := p.rewriteFunc(target, "/heise", false) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") rewrite(pr) @@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { // What the behavior WOULD be if target URL had no path (true stripping) t.Run("target without path prefix gives true stripping", func(t *testing.T) { target, _ := url.Parse("https://heise.de:443") - rewrite := p.rewriteFunc(target, "/heise", false) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") rewrite(pr) @@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) { target, _ := url.Parse("https://heise.de:443") - rewrite := p.rewriteFunc(target, "/heise", false) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") rewrite(pr) @@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { // Root path "/" — no stripping expected t.Run("root path forwards full request path unchanged", func(t *testing.T) { target, _ := url.Parse("https://backend.example.com:443/") - rewrite := p.rewriteFunc(target, "/", false) + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") rewrite(pr) @@ -546,6 +546,109 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { }) } +func TestRewriteFunc_PreservePath(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto"} + target, _ := url.Parse("http://backend.internal:8080") + + t.Run("preserve keeps full request path", func(t *testing.T) { + rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil, nil) + pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000") + + rewrite(pr) + + assert.Equal(t, "/api/users/123", pr.Out.URL.Path, + "preserve should keep the full original request path") + }) + + t.Run("preserve with root matchedPath", func(t *testing.T) { + rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil, nil) + pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000") + + rewrite(pr) + + assert.Equal(t, "/anything", pr.Out.URL.Path) + }) +} + +func TestRewriteFunc_CustomHeaders(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto"} + target, _ := url.Parse("http://backend.internal:8080") + + t.Run("injects custom headers", func(t *testing.T) { + headers := map[string]string{ + "X-Custom-Auth": "token-abc", + "X-Env": "production", + } + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + + rewrite(pr) + + assert.Equal(t, "token-abc", pr.Out.Header.Get("X-Custom-Auth")) + assert.Equal(t, "production", pr.Out.Header.Get("X-Env")) + }) + + t.Run("nil customHeaders is fine", func(t *testing.T) { + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + + rewrite(pr) + + assert.Equal(t, "backend.internal:8080", pr.Out.Host) + }) + + t.Run("custom headers override existing request headers", func(t *testing.T) { + headers := map[string]string{"X-Override": "new-value"} + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + pr.In.Header.Set("X-Override", "old-value") + + rewrite(pr) + + assert.Equal(t, "new-value", pr.Out.Header.Get("X-Override")) + }) +} + +func TestRewriteFunc_StripsAuthorizationHeader(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto"} + target, _ := url.Parse("http://backend.internal:8080") + + t.Run("strips incoming Authorization when no custom Authorization set", func(t *testing.T) { + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, []string{"Authorization"}) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + pr.In.Header.Set("Authorization", "Bearer proxy-token") + + rewrite(pr) + + assert.Empty(t, pr.Out.Header.Get("Authorization"), "Authorization should be stripped") + }) + + t.Run("custom Authorization replaces incoming", func(t *testing.T) { + headers := map[string]string{"Authorization": "Basic YmFja2VuZDpzZWNyZXQ="} + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, []string{"Authorization"}) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + pr.In.Header.Set("Authorization", "Bearer proxy-token") + + rewrite(pr) + + assert.Equal(t, "Basic YmFja2VuZDpzZWNyZXQ=", pr.Out.Header.Get("Authorization"), + "backend Authorization from custom headers should be set") + }) +} + +func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto"} + target, _ := url.Parse("http://backend.internal:8080") + + rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}, nil) + pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000") + + rewrite(pr) + + assert.Equal(t, "/api/deep/path", pr.Out.URL.Path, "preserve should keep the full original path") + assert.Equal(t, "proxy", pr.Out.Header.Get("X-Via"), "custom header should be set") +} + func TestRewriteLocationFunc(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} } diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 6f5829ebb..fe470cf01 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -6,26 +6,53 @@ import ( "net/url" "sort" "strings" + "time" "github.com/netbirdio/netbird/proxy/internal/types" ) +// PathRewriteMode controls how the request path is rewritten before forwarding. +type PathRewriteMode int + +const ( + // PathRewriteDefault strips the matched prefix and joins with the target path. + PathRewriteDefault PathRewriteMode = iota + // PathRewritePreserve keeps the full original request path as-is. + PathRewritePreserve +) + +// PathTarget holds a backend URL and per-target behavioral options. +type PathTarget struct { + URL *url.URL + SkipTLSVerify bool + RequestTimeout time.Duration + PathRewrite PathRewriteMode + CustomHeaders map[string]string +} + +// Mapping describes how a domain is routed by the HTTP reverse proxy. type Mapping struct { - ID string + ID types.ServiceID AccountID types.AccountID Host string - Paths map[string]*url.URL + Paths map[string]*PathTarget PassHostHeader bool RewriteRedirects bool + // StripAuthHeaders are header names used for header-based auth. + // These headers are stripped from requests before forwarding. + StripAuthHeaders []string + // sortedPaths caches the paths sorted by length (longest first). + sortedPaths []string } type targetResult struct { - url *url.URL + target *PathTarget matchedPath string - serviceID string + serviceID types.ServiceID accountID types.AccountID passHostHeader bool rewriteRedirects bool + stripAuthHeaders []string } func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) { @@ -44,26 +71,22 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo return targetResult{}, false } - // Sort paths by length (longest first) in a naive attempt to match the most specific route first. - paths := make([]string, 0, len(m.Paths)) - for path := range m.Paths { - paths = append(paths, path) - } - sort.Slice(paths, func(i, j int) bool { - return len(paths[i]) > len(paths[j]) - }) - - for _, path := range paths { + for _, path := range m.sortedPaths { if strings.HasPrefix(req.URL.Path, path) { - target := m.Paths[path] - p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target) + pt := m.Paths[path] + if pt == nil || pt.URL == nil { + p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path) + continue + } + p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL) return targetResult{ - url: target, + target: pt, matchedPath: path, serviceID: m.ID, accountID: m.AccountID, passHostHeader: m.PassHostHeader, rewriteRedirects: m.RewriteRedirects, + stripAuthHeaders: m.StripAuthHeaders, }, true } } @@ -71,14 +94,30 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo return targetResult{}, false } +// AddMapping registers a host-to-backend mapping for the reverse proxy. func (p *ReverseProxy) AddMapping(m Mapping) { + // Sort paths longest-first to match the most specific route first. + paths := make([]string, 0, len(m.Paths)) + for path := range m.Paths { + paths = append(paths, path) + } + sort.Slice(paths, func(i, j int) bool { + return len(paths[i]) > len(paths[j]) + }) + m.sortedPaths = paths + p.mappingsMux.Lock() defer p.mappingsMux.Unlock() p.mappings[m.Host] = m } -func (p *ReverseProxy) RemoveMapping(m Mapping) { +// RemoveMapping removes the mapping for the given host and reports whether it existed. +func (p *ReverseProxy) RemoveMapping(m Mapping) bool { p.mappingsMux.Lock() defer p.mappingsMux.Unlock() + if _, ok := p.mappings[m.Host]; !ok { + return false + } delete(p.mappings, m.Host) + return true } diff --git a/proxy/internal/proxy/trustedproxy.go b/proxy/internal/proxy/trustedproxy.go index ad9a5b6c0..0fe693f90 100644 --- a/proxy/internal/proxy/trustedproxy.go +++ b/proxy/internal/proxy/trustedproxy.go @@ -7,21 +7,11 @@ import ( // IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes. func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { - if len(trusted) == 0 { - return false - } - addr, err := netip.ParseAddr(ipStr) - if err != nil { + if err != nil || len(trusted) == 0 { return false } - - for _, prefix := range trusted { - if prefix.Contains(addr) { - return true - } - } - return false + return isTrustedAddr(addr.Unmap(), trusted) } // ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list. @@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { // // If the trusted list is empty or remoteAddr is not trusted, it returns the // remoteAddr IP directly (ignoring any forwarding headers). -func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { - remoteIP := extractClientIP(remoteAddr) +func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr { + remoteIP := extractHostIP(remoteAddr) - if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) { + if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) { return remoteIP } @@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { if ip == "" { continue } - if !IsTrustedProxy(ip, trusted) { - return ip + addr, err := netip.ParseAddr(ip) + if err != nil { + continue + } + addr = addr.Unmap() + if !isTrustedAddr(addr, trusted) { + return addr } } // All IPs in XFF are trusted; return the leftmost as best guess. if first := strings.TrimSpace(parts[0]); first != "" { - return first + if addr, err := netip.ParseAddr(first); err == nil { + return addr.Unmap() + } } return remoteIP } + +// extractHostIP parses the IP from a host:port string and returns it unmapped. +func extractHostIP(hostPort string) netip.Addr { + if ap, err := netip.ParseAddrPort(hostPort); err == nil { + return ap.Addr().Unmap() + } + if addr, err := netip.ParseAddr(hostPort); err == nil { + return addr.Unmap() + } + return netip.Addr{} +} + +// isTrustedAddr checks if the given address falls within any of the trusted prefixes. +func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool { + if !addr.IsValid() { + return false + } + for _, prefix := range trusted { + if prefix.Contains(addr) { + return true + } + } + return false +} diff --git a/proxy/internal/proxy/trustedproxy_test.go b/proxy/internal/proxy/trustedproxy_test.go index 827b7babf..35ed1f5c2 100644 --- a/proxy/internal/proxy/trustedproxy_test.go +++ b/proxy/internal/proxy/trustedproxy_test.go @@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) { remoteAddr string xff string trusted []netip.Prefix - want string + want netip.Addr }{ { name: "empty trusted list returns RemoteAddr", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4", trusted: nil, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "untrusted RemoteAddr ignores XFF", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4, 10.0.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with single client in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr walks past trusted entries in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50, 10.0.0.2, 172.16.0.5", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr", remoteAddr: "10.0.0.1:5000", xff: "", trusted: trusted, - want: "10.0.0.1", + want: netip.MustParseAddr("10.0.0.1"), }, { name: "all XFF IPs trusted returns leftmost", remoteAddr: "10.0.0.1:5000", xff: "10.0.0.2, 172.16.0.1, 10.0.0.3", trusted: trusted, - want: "10.0.0.2", + want: netip.MustParseAddr("10.0.0.2"), }, { name: "XFF with whitespace", remoteAddr: "10.0.0.1:5000", xff: " 203.0.113.50 , 10.0.0.2 ", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "XFF with empty segments", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50,,10.0.0.2", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "multi-hop with mixed trust", remoteAddr: "10.0.0.1:5000", xff: "8.8.8.8, 203.0.113.50, 172.16.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "RemoteAddr without port", remoteAddr: "10.0.0.1", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, } for _, tt := range tests { diff --git a/proxy/internal/restrict/restrict.go b/proxy/internal/restrict/restrict.go new file mode 100644 index 000000000..f3e0fa695 --- /dev/null +++ b/proxy/internal/restrict/restrict.go @@ -0,0 +1,315 @@ +// Package restrict provides connection-level access control based on +// IP CIDR ranges and geolocation (country codes). +package restrict + +import ( + "net/netip" + "slices" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/geolocation" +) + +// defaultLogger is used when no logger is provided to ParseFilter. +var defaultLogger = log.NewEntry(log.StandardLogger()) + +// GeoResolver resolves an IP address to geographic information. +type GeoResolver interface { + LookupAddr(addr netip.Addr) geolocation.Result + Available() bool +} + +// DecisionType is the type of CrowdSec remediation action. +type DecisionType string + +const ( + DecisionBan DecisionType = "ban" + DecisionCaptcha DecisionType = "captcha" + DecisionThrottle DecisionType = "throttle" +) + +// CrowdSecDecision holds the type of a CrowdSec decision. +type CrowdSecDecision struct { + Type DecisionType +} + +// CrowdSecChecker queries CrowdSec decisions for an IP address. +type CrowdSecChecker interface { + CheckIP(addr netip.Addr) *CrowdSecDecision + Ready() bool +} + +// CrowdSecMode is the per-service enforcement mode. +type CrowdSecMode string + +const ( + CrowdSecOff CrowdSecMode = "" + CrowdSecEnforce CrowdSecMode = "enforce" + CrowdSecObserve CrowdSecMode = "observe" +) + +// Filter evaluates IP restrictions. CIDR checks are performed first +// (cheap), followed by country lookups (more expensive) only when needed. +type Filter struct { + AllowedCIDRs []netip.Prefix + BlockedCIDRs []netip.Prefix + AllowedCountries []string + BlockedCountries []string + CrowdSec CrowdSecChecker + CrowdSecMode CrowdSecMode +} + +// FilterConfig holds the raw configuration for building a Filter. +type FilterConfig struct { + AllowedCIDRs []string + BlockedCIDRs []string + AllowedCountries []string + BlockedCountries []string + CrowdSec CrowdSecChecker + CrowdSecMode CrowdSecMode + Logger *log.Entry +} + +// ParseFilter builds a Filter from the config. Returns nil if no restrictions +// are configured. +func ParseFilter(cfg FilterConfig) *Filter { + hasCS := cfg.CrowdSecMode == CrowdSecEnforce || cfg.CrowdSecMode == CrowdSecObserve + if len(cfg.AllowedCIDRs) == 0 && len(cfg.BlockedCIDRs) == 0 && + len(cfg.AllowedCountries) == 0 && len(cfg.BlockedCountries) == 0 && !hasCS { + return nil + } + + logger := cfg.Logger + if logger == nil { + logger = defaultLogger + } + + f := &Filter{ + AllowedCountries: normalizeCountryCodes(cfg.AllowedCountries), + BlockedCountries: normalizeCountryCodes(cfg.BlockedCountries), + } + if hasCS { + f.CrowdSec = cfg.CrowdSec + f.CrowdSecMode = cfg.CrowdSecMode + } + for _, cidr := range cfg.AllowedCIDRs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + logger.Warnf("skip invalid allowed CIDR %q: %v", cidr, err) + continue + } + f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked()) + } + for _, cidr := range cfg.BlockedCIDRs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + logger.Warnf("skip invalid blocked CIDR %q: %v", cidr, err) + continue + } + f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked()) + } + return f +} + +func normalizeCountryCodes(codes []string) []string { + if len(codes) == 0 { + return nil + } + out := make([]string, len(codes)) + for i, c := range codes { + out[i] = strings.ToUpper(c) + } + return out +} + +// Verdict is the result of an access check. +type Verdict int + +const ( + // Allow indicates the address passed all checks. + Allow Verdict = iota + // DenyCIDR indicates the address was blocked by a CIDR rule. + DenyCIDR + // DenyCountry indicates the address was blocked by a country rule. + DenyCountry + // DenyGeoUnavailable indicates that country restrictions are configured + // but the geo lookup is unavailable. + DenyGeoUnavailable + // DenyCrowdSecBan indicates a CrowdSec "ban" decision. + DenyCrowdSecBan + // DenyCrowdSecCaptcha indicates a CrowdSec "captcha" decision. + DenyCrowdSecCaptcha + // DenyCrowdSecThrottle indicates a CrowdSec "throttle" decision. + DenyCrowdSecThrottle + // DenyCrowdSecUnavailable indicates enforce mode but the bouncer has not + // completed its initial sync. + DenyCrowdSecUnavailable +) + +// String returns the deny reason string matching the HTTP auth mechanism names. +func (v Verdict) String() string { + switch v { + case Allow: + return "allow" + case DenyCIDR: + return "ip_restricted" + case DenyCountry: + return "country_restricted" + case DenyGeoUnavailable: + return "geo_unavailable" + case DenyCrowdSecBan: + return "crowdsec_ban" + case DenyCrowdSecCaptcha: + return "crowdsec_captcha" + case DenyCrowdSecThrottle: + return "crowdsec_throttle" + case DenyCrowdSecUnavailable: + return "crowdsec_unavailable" + default: + return "unknown" + } +} + +// IsCrowdSec returns true when the verdict originates from a CrowdSec check. +func (v Verdict) IsCrowdSec() bool { + switch v { + case DenyCrowdSecBan, DenyCrowdSecCaptcha, DenyCrowdSecThrottle, DenyCrowdSecUnavailable: + return true + default: + return false + } +} + +// IsObserveOnly returns true when v is a CrowdSec verdict and the filter is in +// observe mode. Callers should log the verdict but not block the request. +func (f *Filter) IsObserveOnly(v Verdict) bool { + if f == nil { + return false + } + return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve +} + +// Check evaluates whether addr is permitted. CIDR rules are evaluated +// first because they are O(n) prefix comparisons. Country rules run +// only when CIDR checks pass and require a geo lookup. CrowdSec checks +// run last. +func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict { + if f == nil { + return Allow + } + + // Normalize v4-mapped-v6 (e.g. ::ffff:10.1.2.3) to plain v4 so that + // IPv4 CIDR rules match regardless of how the address was received. + addr = addr.Unmap() + + if v := f.checkCIDR(addr); v != Allow { + return v + } + if v := f.checkCountry(addr, geo); v != Allow { + return v + } + return f.checkCrowdSec(addr) +} + +func (f *Filter) checkCIDR(addr netip.Addr) Verdict { + if len(f.AllowedCIDRs) > 0 { + allowed := false + for _, prefix := range f.AllowedCIDRs { + if prefix.Contains(addr) { + allowed = true + break + } + } + if !allowed { + return DenyCIDR + } + } + + for _, prefix := range f.BlockedCIDRs { + if prefix.Contains(addr) { + return DenyCIDR + } + } + return Allow +} + +func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict { + if len(f.AllowedCountries) == 0 && len(f.BlockedCountries) == 0 { + return Allow + } + + if geo == nil || !geo.Available() { + return DenyGeoUnavailable + } + + result := geo.LookupAddr(addr) + if result.CountryCode == "" { + // Unknown country: deny if an allowlist is active, allow otherwise. + // Blocklists are best-effort: unknown countries pass through since + // the default policy is allow. + if len(f.AllowedCountries) > 0 { + return DenyCountry + } + return Allow + } + + if len(f.AllowedCountries) > 0 { + if !slices.Contains(f.AllowedCountries, result.CountryCode) { + return DenyCountry + } + } + + if slices.Contains(f.BlockedCountries, result.CountryCode) { + return DenyCountry + } + + return Allow +} + +func (f *Filter) checkCrowdSec(addr netip.Addr) Verdict { + if f.CrowdSecMode == CrowdSecOff { + return Allow + } + + // Checker nil with enforce means CrowdSec was requested but the proxy + // has no LAPI configured. Fail-closed. + if f.CrowdSec == nil { + if f.CrowdSecMode == CrowdSecEnforce { + return DenyCrowdSecUnavailable + } + return Allow + } + + if !f.CrowdSec.Ready() { + if f.CrowdSecMode == CrowdSecEnforce { + return DenyCrowdSecUnavailable + } + return Allow + } + + d := f.CrowdSec.CheckIP(addr) + if d == nil { + return Allow + } + + switch d.Type { + case DecisionCaptcha: + return DenyCrowdSecCaptcha + case DecisionThrottle: + return DenyCrowdSecThrottle + default: + return DenyCrowdSecBan + } +} + +// HasRestrictions returns true if any restriction rules are configured. +func (f *Filter) HasRestrictions() bool { + if f == nil { + return false + } + return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 || + len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0 || + f.CrowdSecMode == CrowdSecEnforce || f.CrowdSecMode == CrowdSecObserve +} diff --git a/proxy/internal/restrict/restrict_test.go b/proxy/internal/restrict/restrict_test.go new file mode 100644 index 000000000..abaa1afdc --- /dev/null +++ b/proxy/internal/restrict/restrict_test.go @@ -0,0 +1,526 @@ +package restrict + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/proxy/internal/geolocation" +) + +type mockGeo struct { + countries map[string]string +} + +func (m *mockGeo) LookupAddr(addr netip.Addr) geolocation.Result { + return geolocation.Result{CountryCode: m.countries[addr.String()]} +} + +func (m *mockGeo) Available() bool { return true } + +func newMockGeo(entries map[string]string) *mockGeo { + return &mockGeo{countries: entries} +} + +func TestFilter_Check_NilFilter(t *testing.T) { + var f *Filter + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_Check_AllowedCIDR(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) +} + +func TestFilter_Check_BlockedCIDR(t *testing.T) { + f := ParseFilter(FilterConfig{BlockedCIDRs: []string{"10.0.0.0/8"}}) + + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) +} + +func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, BlockedCIDRs: []string{"10.1.0.0/16"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "not in allowlist") +} + +func TestFilter_Check_AllowedCountry(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + "2.2.2.2": "DE", + "3.3.3.3": "CN", + }) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US", "DE"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist") +} + +func TestFilter_Check_BlockedCountry(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "CN", + "2.2.2.2": "RU", + "3.3.3.3": "US", + }) + f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN", "RU"}}) + + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "US not in blocklist") +} + +func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + "2.2.2.2": "DE", + "3.3.3.3": "CN", + }) + // Allow US and DE, but block DE explicitly. + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US", "DE"}, BlockedCountries: []string{"DE"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist") +} + +func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + }) + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active") +} + +func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "CN", + }) + f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}}) + + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active") +} + +func TestFilter_Check_CountryWithoutGeo(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}) + assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist") +} + +func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) { + f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}}) + assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist") +} + +func TestFilter_Check_GeoUnavailable(t *testing.T) { + geo := &unavailableGeo{} + + f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}) + assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist") + + f2 := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}}) + assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist") +} + +func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + + // CIDR-only filter should never touch geo, so nil geo is fine. + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) +} + +func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) { + geo := newMockGeo(map[string]string{ + "10.1.2.3": "CN", + "10.2.3.4": "US", + }) + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, BlockedCountries: []string{"CN"}}) + + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), geo), "CIDR denied before country check") +} + +func TestParseFilter_Empty(t *testing.T) { + f := ParseFilter(FilterConfig{}) + assert.Nil(t, f) +} + +func TestParseFilter_InvalidCIDR(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"invalid", "10.0.0.0/8"}}) + + assert.NotNil(t, f) + assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped") + assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), f.AllowedCIDRs[0]) +} + +func TestFilter_HasRestrictions(t *testing.T) { + assert.False(t, (*Filter)(nil).HasRestrictions()) + assert.False(t, (&Filter{}).HasRestrictions()) + assert.True(t, ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}).HasRestrictions()) + assert.True(t, ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}).HasRestrictions()) +} + +func TestFilter_Check_IPv6CIDR(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"2001:db8::/32"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 addr not in v6 allowlist") +} + +func TestFilter_Check_IPv4MappedIPv6(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + + // A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR. + v4mapped := netip.MustParseAddr("::ffff:10.1.2.3") + assert.True(t, v4mapped.Is4In6(), "precondition: address is v4-in-v6") + assert.Equal(t, Allow, f.Check(v4mapped, nil), "v4-mapped-v6 must match v4 CIDR after Unmap") + + v4mappedOutside := netip.MustParseAddr("::ffff:192.168.1.1") + assert.Equal(t, DenyCIDR, f.Check(v4mappedOutside, nil), "v4-mapped-v6 outside v4 CIDR") +} + +func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8", "2001:db8::/32"}}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("::ffff:10.1.2.3"), nil), "v4-mapped matches v4 CIDR") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "v4 not in either CIDR") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("fe80::1"), nil), "v6 not in either CIDR") +} + +func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) { + // 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24. + f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"1.1.1.1/24"}}) + assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0]) + + // Verify it still matches correctly. + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.100"), nil)) + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("1.1.2.1"), nil)) +} + +func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + "2.2.2.2": "DE", + "3.3.3.3": "CN", + }) + + tests := []struct { + name string + allowedCountries []string + blockedCountries []string + addr string + want Verdict + }{ + { + name: "lowercase allowlist matches uppercase MaxMind code", + allowedCountries: []string{"us", "de"}, + addr: "1.1.1.1", + want: Allow, + }, + { + name: "mixed-case allowlist matches", + allowedCountries: []string{"Us", "dE"}, + addr: "2.2.2.2", + want: Allow, + }, + { + name: "lowercase allowlist rejects non-matching country", + allowedCountries: []string{"us", "de"}, + addr: "3.3.3.3", + want: DenyCountry, + }, + { + name: "lowercase blocklist blocks matching country", + blockedCountries: []string{"cn"}, + addr: "3.3.3.3", + want: DenyCountry, + }, + { + name: "mixed-case blocklist blocks matching country", + blockedCountries: []string{"Cn"}, + addr: "3.3.3.3", + want: DenyCountry, + }, + { + name: "lowercase blocklist does not block non-matching country", + blockedCountries: []string{"cn"}, + addr: "1.1.1.1", + want: Allow, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := ParseFilter(FilterConfig{AllowedCountries: tc.allowedCountries, BlockedCountries: tc.blockedCountries}) + got := f.Check(netip.MustParseAddr(tc.addr), geo) + assert.Equal(t, tc.want, got) + }) + } +} + +// unavailableGeo simulates a GeoResolver whose database is not loaded. +type unavailableGeo struct{} + +func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} } +func (u *unavailableGeo) Available() bool { return false } + +// mockCrowdSec is a test implementation of CrowdSecChecker. +type mockCrowdSec struct { + decisions map[string]*CrowdSecDecision + ready bool +} + +func (m *mockCrowdSec) CheckIP(addr netip.Addr) *CrowdSecDecision { + return m.decisions[addr.Unmap().String()] +} + +func (m *mockCrowdSec) Ready() bool { return m.ready } + +func TestFilter_CrowdSec_Enforce_Ban(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecBan, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("5.6.7.8"), nil)) +} + +func TestFilter_CrowdSec_Enforce_Captcha(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionCaptcha}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecCaptcha, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Enforce_Throttle(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionThrottle}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecThrottle, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Observe_DoesNotBlock(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecObserve}) + + verdict := f.Check(netip.MustParseAddr("1.2.3.4"), nil) + assert.Equal(t, DenyCrowdSecBan, verdict, "verdict should be ban") + assert.True(t, f.IsObserveOnly(verdict), "should be observe-only") +} + +func TestFilter_CrowdSec_Enforce_NotReady(t *testing.T) { + cs := &mockCrowdSec{ready: false} + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Observe_NotReady_Allows(t *testing.T) { + cs := &mockCrowdSec{ready: false} + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecObserve}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Off(t *testing.T) { + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}}, + ready: true, + } + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecOff}) + + // CrowdSecOff means the filter is nil (no restrictions). + assert.Nil(t, f) +} + +func TestFilter_IsObserveOnly(t *testing.T) { + f := &Filter{CrowdSecMode: CrowdSecObserve} + assert.True(t, f.IsObserveOnly(DenyCrowdSecBan)) + assert.True(t, f.IsObserveOnly(DenyCrowdSecCaptcha)) + assert.True(t, f.IsObserveOnly(DenyCrowdSecThrottle)) + assert.True(t, f.IsObserveOnly(DenyCrowdSecUnavailable)) + assert.False(t, f.IsObserveOnly(DenyCIDR)) + assert.False(t, f.IsObserveOnly(Allow)) + + f2 := &Filter{CrowdSecMode: CrowdSecEnforce} + assert.False(t, f2.IsObserveOnly(DenyCrowdSecBan)) +} + +// TestFilter_LayerInteraction exercises the evaluation order across all three +// restriction layers: CIDR -> Country -> CrowdSec. Each layer can only further +// restrict; no layer can relax a denial from an earlier layer. +// +// Layer order | Behavior +// ---------------|------------------------------------------------------- +// 1. CIDR | Allowlist narrows to specific ranges, blocklist removes +// | specific ranges. Deny here → stop, CrowdSec never runs. +// 2. Country | Allowlist/blocklist by geo. Deny here → stop. +// 3. CrowdSec | IP reputation. Can block IPs that passed layers 1-2. +// | Observe mode: verdict returned but caller doesn't block. +func TestFilter_LayerInteraction(t *testing.T) { + bannedIP := "10.1.2.3" + cleanIP := "10.2.3.4" + outsideIP := "192.168.1.1" + + cs := &mockCrowdSec{ + decisions: map[string]*CrowdSecDecision{bannedIP: {Type: DecisionBan}}, + ready: true, + } + geo := newMockGeo(map[string]string{ + bannedIP: "US", + cleanIP: "US", + outsideIP: "CN", + }) + + tests := []struct { + name string + config FilterConfig + addr string + want Verdict + }{ + // CIDR allowlist + CrowdSec enforce: CrowdSec blocks inside allowed range + { + name: "allowed CIDR + CrowdSec banned", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: bannedIP, + want: DenyCrowdSecBan, + }, + { + name: "allowed CIDR + CrowdSec clean", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: cleanIP, + want: Allow, + }, + { + name: "CIDR deny stops before CrowdSec", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: outsideIP, + want: DenyCIDR, + }, + + // CIDR blocklist + CrowdSec enforce: blocklist blocks first, CrowdSec blocks remaining + { + name: "blocked CIDR stops before CrowdSec", + config: FilterConfig{BlockedCIDRs: []string{"10.1.0.0/16"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: bannedIP, + want: DenyCIDR, + }, + { + name: "not in blocklist + CrowdSec clean", + config: FilterConfig{BlockedCIDRs: []string{"10.1.0.0/16"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: cleanIP, + want: Allow, + }, + + // Country allowlist + CrowdSec enforce + { + name: "allowed country + CrowdSec banned", + config: FilterConfig{AllowedCountries: []string{"US"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: bannedIP, + want: DenyCrowdSecBan, + }, + { + name: "country deny stops before CrowdSec", + config: FilterConfig{AllowedCountries: []string{"US"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}, + addr: outsideIP, + want: DenyCountry, + }, + + // All three layers: CIDR allowlist + country blocklist + CrowdSec + { + name: "all layers: CIDR allow + country allow + CrowdSec ban", + config: FilterConfig{ + AllowedCIDRs: []string{"10.0.0.0/8"}, + BlockedCountries: []string{"CN"}, + CrowdSec: cs, + CrowdSecMode: CrowdSecEnforce, + }, + addr: bannedIP, // 10.x (CIDR ok), US (country ok), banned (CrowdSec deny) + want: DenyCrowdSecBan, + }, + { + name: "all layers: CIDR deny short-circuits everything", + config: FilterConfig{ + AllowedCIDRs: []string{"10.0.0.0/8"}, + BlockedCountries: []string{"CN"}, + CrowdSec: cs, + CrowdSecMode: CrowdSecEnforce, + }, + addr: outsideIP, // 192.x (CIDR deny) + want: DenyCIDR, + }, + + // Observe mode: verdict returned but IsObserveOnly is true + { + name: "observe mode: CrowdSec banned inside allowed CIDR", + config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecObserve}, + addr: bannedIP, + want: DenyCrowdSecBan, // verdict is ban, caller checks IsObserveOnly + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := ParseFilter(tc.config) + got := f.Check(netip.MustParseAddr(tc.addr), geo) + assert.Equal(t, tc.want, got) + + // Verify observe mode flag when applicable. + if tc.config.CrowdSecMode == CrowdSecObserve && got.IsCrowdSec() { + assert.True(t, f.IsObserveOnly(got), "observe mode verdict should be observe-only") + } + if tc.config.CrowdSecMode == CrowdSecEnforce && got.IsCrowdSec() { + assert.False(t, f.IsObserveOnly(got), "enforce mode verdict should not be observe-only") + } + }) + } +} + +func TestFilter_CrowdSec_Enforce_NilChecker(t *testing.T) { + // LAPI not configured: checker is nil but mode is enforce. Must fail closed. + f := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecEnforce}) + + assert.Equal(t, DenyCrowdSecUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) { + // LAPI not configured: checker is nil but mode is observe. Must allow. + f := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecObserve}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_HasRestrictions_CrowdSec(t *testing.T) { + cs := &mockCrowdSec{ready: true} + f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) + assert.True(t, f.HasRestrictions()) + + // Enforce mode without checker (LAPI not configured): still has restrictions + // because Check() will fail-closed with DenyCrowdSecUnavailable. + f2 := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecEnforce}) + assert.True(t, f2.HasRestrictions()) +} diff --git a/proxy/internal/roundtrip/context_test.go b/proxy/internal/roundtrip/context_test.go new file mode 100644 index 000000000..c4e8267f8 --- /dev/null +++ b/proxy/internal/roundtrip/context_test.go @@ -0,0 +1,32 @@ +package roundtrip + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestAccountIDContext(t *testing.T) { + t.Run("returns empty when missing", func(t *testing.T) { + assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background())) + }) + + t.Run("round-trips value", func(t *testing.T) { + ctx := WithAccountID(context.Background(), "acc-123") + assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx)) + }) +} + +func TestSkipTLSVerifyContext(t *testing.T) { + t.Run("false by default", func(t *testing.T) { + assert.False(t, skipTLSVerifyFromContext(context.Background())) + }) + + t.Run("true when set", func(t *testing.T) { + ctx := WithSkipTLSVerify(context.Background()) + assert.True(t, skipTLSVerifyFromContext(ctx)) + }) +} diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index d7fd2746f..e38e3dc4e 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -2,8 +2,10 @@ package roundtrip import ( "context" + "crypto/tls" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -13,11 +15,12 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/embed" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -25,7 +28,22 @@ import ( const deviceNamePrefix = "ingress-proxy-" // backendKey identifies a backend by its host:port from the target URL. -type backendKey = string +type backendKey string + +// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service) +// that holds a reference to an embedded NetBird client. Callers should use the +// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions. +type ServiceKey string + +// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service. +func DomainServiceKey(domain string) ServiceKey { + return ServiceKey("domain:" + domain) +} + +// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP). +func L4ServiceKey(id types.ServiceID) ServiceKey { + return ServiceKey("l4:" + id) +} var ( // ErrNoAccountID is returned when a request context is missing the account ID. @@ -38,23 +56,26 @@ var ( ErrTooManyInflight = errors.New("too many in-flight requests") ) -// domainInfo holds metadata about a registered domain. -type domainInfo struct { - serviceID string +// serviceInfo holds metadata about a registered service. +type serviceInfo struct { + serviceID types.ServiceID } -type domainNotification struct { - domain domain.Domain - serviceID string +type serviceNotification struct { + key ServiceKey + serviceID types.ServiceID } -// clientEntry holds an embedded NetBird client and tracks which domains use it. +// clientEntry holds an embedded NetBird client and tracks which services use it. type clientEntry struct { client *embed.Client transport *http.Transport - domains map[domain.Domain]domainInfo - createdAt time.Time - started bool + // insecureTransport is a clone of transport with TLS verification disabled, + // used when per-target skip_tls_verify is set. + insecureTransport *http.Transport + services map[ServiceKey]serviceInfo + createdAt time.Time + started bool // Per-backend in-flight limiting keyed by target host:port. // TODO: clean up stale entries when backend targets change. inflightMu sync.Mutex @@ -86,8 +107,15 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo } } +// ClientConfig holds configuration for the embedded NetBird client. +type ClientConfig struct { + MgmtAddr string + WGPort uint16 + PreSharedKey string +} + type statusNotifier interface { - NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error + NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error } type managementClient interface { @@ -96,12 +124,11 @@ type managementClient interface { // NetBird provides an http.RoundTripper implementation // backed by underlying NetBird connections. -// Clients are keyed by AccountID, allowing multiple domains to share the same connection. +// Clients are keyed by AccountID, allowing multiple services to share the same connection. type NetBird struct { - mgmtAddr string proxyID string proxyAddr string - wgPort int + clientCfg ClientConfig logger *log.Logger mgmtClient managementClient transportCfg transportConfig @@ -114,47 +141,50 @@ type NetBird struct { // ClientDebugInfo contains debug information about a client. type ClientDebugInfo struct { - AccountID types.AccountID - DomainCount int - Domains domain.List - HasClient bool - CreatedAt time.Time + AccountID types.AccountID + ServiceCount int + ServiceKeys []string + HasClient bool + CreatedAt time.Time } // accountIDContextKey is the context key for storing the account ID. type accountIDContextKey struct{} -// AddPeer registers a domain for an account. If the account doesn't have a client yet, +// skipTLSVerifyContextKey is the context key for requesting insecure TLS. +type skipTLSVerifyContextKey struct{} + +// AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. -// Multiple domains can share the same client. -func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error { +// Multiple services can share the same client. +func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { + si := serviceInfo{serviceID: serviceID} + n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { - // Client already exists for this account, just register the domain - entry.domains[d] = domainInfo{serviceID: serviceID} + entry.services[key] = si started := entry.started n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("registered domain with existing client") + "account_id": accountID, + "service_key": key, + }).Debug("registered service with existing client") - // If client is already started, notify this domain as connected immediately if started && n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).WithError(err).Warn("failed to notify status for existing client") } } return nil } - entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID) + entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) if err != nil { n.clientsMux.Unlock() return err @@ -164,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).Info("created new client for account") // Attempt to start the client in the background; if this fails we will @@ -177,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma // createClientEntry generates a WireGuard keypair, authenticates with management, // and creates an embedded NetBird client. Must be called with clientsMux held. -func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) { +func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { + serviceID := si.serviceID n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, @@ -196,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }).Debug("authenticating new proxy peer with management") resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{ - ServiceId: serviceID, + ServiceId: string(serviceID), AccountId: string(accountID), Token: authToken, WireguardPublicKey: publicKey.String(), @@ -227,13 +258,15 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // Create embedded NetBird client with the generated private key. // The peer has already been created via CreateProxyPeer RPC with the public key. + wgPort := int(n.clientCfg.WGPort) client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, - ManagementURL: n.mgmtAddr, + ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, - WireguardPort: &n.wgPort, + WireguardPort: &wgPort, + PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) @@ -242,31 +275,37 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // Create a transport using the client dialer. We do this instead of using // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. + transport := &http.Transport{ + DialContext: dialWithTimeout(client.DialContext), + ForceAttemptHTTP2: true, + MaxIdleConns: n.transportCfg.maxIdleConns, + MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost, + MaxConnsPerHost: n.transportCfg.maxConnsPerHost, + IdleConnTimeout: n.transportCfg.idleConnTimeout, + TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout, + ExpectContinueTimeout: n.transportCfg.expectContinueTimeout, + ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout, + WriteBufferSize: n.transportCfg.writeBufferSize, + ReadBufferSize: n.transportCfg.readBufferSize, + DisableCompression: n.transportCfg.disableCompression, + } + + insecureTransport := transport.Clone() + insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec + return &clientEntry{ - client: client, - domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}}, - transport: &http.Transport{ - DialContext: client.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: n.transportCfg.maxIdleConns, - MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost, - MaxConnsPerHost: n.transportCfg.maxConnsPerHost, - IdleConnTimeout: n.transportCfg.idleConnTimeout, - TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout, - ExpectContinueTimeout: n.transportCfg.expectContinueTimeout, - ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout, - WriteBufferSize: n.transportCfg.writeBufferSize, - ReadBufferSize: n.transportCfg.readBufferSize, - DisableCompression: n.transportCfg.disableCompression, - }, - createdAt: time.Now(), - started: false, - inflightMap: make(map[backendKey]chan struct{}), - maxInflight: n.transportCfg.maxInflight, + client: client, + services: map[ServiceKey]serviceInfo{key: si}, + transport: transport, + insecureTransport: insecureTransport, + createdAt: time.Now(), + started: false, + inflightMap: make(map[backendKey]chan struct{}), + maxInflight: n.transportCfg.maxInflight, }, nil } -// runClientStartup starts the client and notifies registered domains on success. +// runClientStartup starts the client and notifies registered services on success. func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) { startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -280,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI return } - // Mark client as started and collect domains to notify outside the lock. + // Mark client as started and collect services to notify outside the lock. n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { entry.started = true } - var domainsToNotify []domainNotification + var toNotify []serviceNotification if exists { - for dom, info := range entry.domains { - domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID}) + for key, info := range entry.services { + toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID}) } } n.clientsMux.Unlock() @@ -297,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI if n.statusNotifier == nil { return } - for _, dn := range domainsToNotify { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil { + for _, sn := range toNotify { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).WithError(err).Warn("failed to notify tunnel connection status") } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).Info("notified management about tunnel connection") } } } -// RemovePeer unregisters a domain from an account. The client is only stopped -// when no domains are using it anymore. -func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error { +// RemovePeer unregisters a service from an account. The client is only stopped +// when no services are using it anymore. +func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error { n.clientsMux.Lock() entry, exists := n.clients[accountID] @@ -324,72 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d return nil } - // Get domain info before deleting - domInfo, domainExists := entry.domains[d] - if !domainExists { + si, svcExists := entry.services[key] + if !svcExists { n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("remove peer: domain not registered") + "account_id": accountID, + "service_key": key, + }).Debug("remove peer: service not registered") return nil } - delete(entry.domains, d) - - // If there are still domains using this client, keep it running - if len(entry.domains) > 0 { - n.clientsMux.Unlock() + delete(entry.services, key) + stopClient := len(entry.services) == 0 + var client *embed.Client + var transport, insecureTransport *http.Transport + if stopClient { + n.logger.WithField("account_id", accountID).Info("stopping client, no more services") + client = entry.client + transport = entry.transport + insecureTransport = entry.insecureTransport + delete(n.clients, accountID) + } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - "remaining_domains": len(entry.domains), - }).Debug("unregistered domain, client still in use") - - // Notify this domain as disconnected - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") - } - } - return nil + "account_id": accountID, + "service_key": key, + "remaining_services": len(entry.services), + }).Debug("unregistered service, client still in use") } - - // No more domains using this client, stop it - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).Info("stopping client, no more domains") - - client := entry.client - transport := entry.transport - delete(n.clients, accountID) n.clientsMux.Unlock() - // Notify disconnection before stopping - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") + n.notifyDisconnect(ctx, accountID, key, si.serviceID) + + if stopClient { + transport.CloseIdleConnections() + insecureTransport.CloseIdleConnections() + if err := client.Stop(ctx); err != nil { + n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client") } } - transport.CloseIdleConnections() - - if err := client.Stop(ctx); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).WithError(err).Warn("failed to stop netbird client") - } - return nil } +func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) { + if n.statusNotifier == nil { + return + } + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil { + if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound { + n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "service_key": key, + }).WithError(err).Warn("failed to notify tunnel disconnection status") + } + } +} + // RoundTrip implements http.RoundTripper. It looks up the client for the account // specified in the request context and uses it to dial the backend. func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { @@ -408,9 +440,12 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { } client := entry.client transport := entry.transport + if skipTLSVerifyFromContext(req.Context()) { + transport = entry.insecureTransport + } n.clientsMux.RUnlock() - release, ok := entry.acquireInflight(req.URL.Host) + release, ok := entry.acquireInflight(backendKey(req.URL.Host)) defer release() if !ok { return nil, ErrTooManyInflight @@ -450,6 +485,7 @@ func (n *NetBird) StopAll(ctx context.Context) error { var merr *multierror.Error for accountID, entry := range n.clients { entry.transport.CloseIdleConnections() + entry.insecureTransport.CloseIdleConnections() if err := entry.client.Stop(ctx); err != nil { n.logger.WithFields(log.Fields{ "account_id": accountID, @@ -470,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool { return exists } -// DomainCount returns the number of domains registered for the given account. +// ServiceCount returns the number of services registered for the given account. // Returns 0 if the account has no client. -func (n *NetBird) DomainCount(accountID types.AccountID) int { +func (n *NetBird) ServiceCount(accountID types.AccountID) int { n.clientsMux.RLock() defer n.clientsMux.RUnlock() entry, exists := n.clients[accountID] if !exists { return 0 } - return len(entry.domains) + return len(entry.services) } // ClientCount returns the total number of active clients. @@ -507,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { result := make(map[types.AccountID]ClientDebugInfo) for accountID, entry := range n.clients { - domains := make(domain.List, 0, len(entry.domains)) - for d := range entry.domains { - domains = append(domains, d) + keys := make([]string, 0, len(entry.services)) + for k := range entry.services { + keys = append(keys, string(k)) } result[accountID] = ClientDebugInfo{ - AccountID: accountID, - DomainCount: len(entry.domains), - Domains: domains, - HasClient: entry.client != nil, - CreatedAt: entry.createdAt, + AccountID: accountID, + ServiceCount: len(entry.services), + ServiceKeys: keys, + HasClient: entry.client != nil, + CreatedAt: entry.createdAt, } } return result @@ -536,18 +572,17 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client { return result } -// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random +// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random // OS-assigned port. A fixed port only works with single-account deployments; // multiple accounts will fail to bind the same port. -func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { +func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { if logger == nil { logger = log.StandardLogger() } return &NetBird{ - mgmtAddr: mgmtAddr, proxyID: proxyID, proxyAddr: proxyAddr, - wgPort: wgPort, + clientCfg: clientCfg, logger: logger, clients: make(map[types.AccountID]*clientEntry), statusNotifier: notifier, @@ -556,6 +591,20 @@ func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Log } } +// dialWithTimeout wraps a DialContext function so that any dial timeout +// stored in the context (via types.WithDialTimeout) is applied only to +// the connection establishment phase, not the full request lifetime. +func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if d, ok := types.DialTimeoutFromContext(ctx); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, d) + defer cancel() + } + return dial(ctx, network, addr) + } +} + // WithAccountID adds the account ID to the context. func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { return context.WithValue(ctx, accountIDContextKey{}, accountID) @@ -573,3 +622,14 @@ func AccountIDFromContext(ctx context.Context) types.AccountID { } return accountID } + +// WithSkipTLSVerify marks the context to use an insecure transport that skips +// TLS certificate verification for the backend connection. +func WithSkipTLSVerify(ctx context.Context) context.Context { + return context.WithValue(ctx, skipTLSVerifyContextKey{}, true) +} + +func skipTLSVerifyFromContext(ctx context.Context) bool { + v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool) + return v +} diff --git a/proxy/internal/roundtrip/netbird_bench_test.go b/proxy/internal/roundtrip/netbird_bench_test.go index e89213c33..330ea0332 100644 --- a/proxy/internal/roundtrip/netbird_bench_test.go +++ b/proxy/internal/roundtrip/netbird_bench_test.go @@ -1,6 +1,7 @@ package roundtrip import ( + "context" "crypto/rand" "math/big" "sync" @@ -8,7 +9,6 @@ import ( "time" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" ) // Simple benchmark for comparison with AddPeer contention. @@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } // Launch workers that continuously call AddPeer with new random accountIDs. + ctx, cancel := context.WithCancel(b.Context()) var wg sync.WaitGroup for range addPeerWorkers { - wg.Go(func() { - for { - if err := nb.AddPeer(b.Context(), + wg.Add(1) + go func() { + defer wg.Done() + for ctx.Err() == nil { + if err := nb.AddPeer(ctx, types.AccountID(rand.Text()), - domain.Domain(rand.Text()), + ServiceKey(rand.Text()), rand.Text(), - rand.Text()); err != nil { - b.Log(err) + types.ServiceID(rand.Text())); err != nil { + return } } - }) + }() } // Benchmark calling HasClient during AddPeer contention. @@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } }) b.StopTimer() + cancel() + wg.Wait() } diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 3e76af9da..5444f6c11 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -27,16 +26,15 @@ type mockStatusNotifier struct { } type statusCall struct { - accountID string - serviceID string - domain string + accountID types.AccountID + serviceID types.ServiceID connected bool } -func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error { +func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { m.mu.Lock() defer m.mu.Unlock() - m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected}) + m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected}) return nil } @@ -49,7 +47,11 @@ func (m *mockStatusNotifier) calls() []statusCall { // mockNetBird creates a NetBird instance for testing without actually connecting. // It uses an invalid management URL to prevent real connections. func mockNetBird() *NetBird { - return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{}) + return NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, nil, &mockMgmtClient{}) } func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { @@ -58,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { // Initially no client exists. assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") - // Add first domain - this should create a new client. - // Note: This will fail to actually connect since we use an invalid URL, - // but the client entry should still be created. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service - this should create a new client. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID), "should have client after AddPeer") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add first domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - assert.Equal(t, 1, nb.DomainCount(accountID)) + assert.Equal(t, 1, nb.ServiceCount(accountID)) - // Add second domain for the same account - should reuse existing client. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + // Add second service for the same account - should reuse existing client. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service") - // Add third domain. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + // Add third service. + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain") + assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service") // Still only one client. assert.True(t, nb.HasClient(accountID)) @@ -98,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) { account1 := types.AccountID("account-1") account2 := types.AccountID("account-2") - // Add domain for account 1. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add service for account 1. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Add domain for account 2. - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2") + // Add service for account 2. + err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2")) require.NoError(t, err) // Both accounts should have their own clients. assert.True(t, nb.HasClient(account1), "account1 should have client") assert.True(t, nb.HasClient(account2), "account2 should have client") - assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1") - assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1") } -func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) { +func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add multiple domains. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add multiple services. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID)) + assert.Equal(t, 3, nb.ServiceCount(accountID)) - // Remove one domain - client should remain. + // Remove one service - client should remain. err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain") - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2") + assert.True(t, nb.HasClient(accountID), "client should remain after removing one service") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2") - // Remove another domain - client should still remain. + // Remove another service - client should still remain. err = nb.RemovePeer(context.Background(), accountID, "domain2.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.True(t, nb.HasClient(accountID), "client should remain after removing second service") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } -func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) { +func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add single domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add single service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) - // Remove the only domain - client should be removed. - // Note: Stop() may fail since the client never actually connected, - // but the entry should still be removed from the map. + // Remove the only service - client should be removed. _ = nb.RemovePeer(context.Background(), accountID, "domain1.test") - // After removing all domains, client should be gone. - assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + // After removing all services, client should be gone. + assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") } func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { @@ -167,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { assert.NoError(t, err, "removing from non-existent account should not error") } -func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) { +func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add one domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add one service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Remove non-existent domain - should not affect existing domain. - err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test")) + // Remove non-existent service - should not affect existing service. + err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test") require.NoError(t, err) - // Original domain should still be registered. + // Original service should still be registered. assert.True(t, nb.HasClient(accountID)) - assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain") + assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain") } func TestWithAccountID_AndAccountIDFromContext(t *testing.T) { @@ -212,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) { account2 := types.AccountID("account-2") account3 := types.AccountID("account-3") - // Add domains for multiple accounts. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1") + // Add services for multiple accounts. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3") + err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3")) require.NoError(t, err) assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients") // Stop all clients. - // Note: StopAll may return errors since clients never actually connected, - // but the clients should still be removed from the map. _ = nb.StopAll(context.Background()) assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll") @@ -239,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) { assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients") // Add clients for different accounts. - err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1") + err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.Equal(t, 1, nb.ClientCount()) - err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) assert.Equal(t, 2, nb.ClientCount()) - // Adding domain to existing account should not increase count. - err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b") + // Adding service to existing account should not increase count. + err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b")) require.NoError(t, err) - assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count") + assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count") } func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) { @@ -282,11 +278,15 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) { func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { notifier := &mockStatusNotifier{} - nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - // Add first domain — creates a new client entry. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + // Add first service — creates a new client entry. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) // Manually mark client as started to simulate background startup completing. @@ -294,35 +294,38 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { nb.clients[accountID].started = true nb.clientsMux.Unlock() - // Add second domain — should notify immediately since client is already started. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + // Add second service — should notify immediately since client is already started. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, string(accountID), calls[0].accountID) - assert.Equal(t, "svc-2", calls[0].serviceID) - assert.Equal(t, "domain2.test", calls[0].domain) + assert.Equal(t, accountID, calls[0].accountID) + assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID) assert.True(t, calls[0].connected) } func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { notifier := &mockStatusNotifier{} - nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) - // Remove one domain — client stays, but disconnection notification fires. + // Remove one service — client stays, but disconnection notification fires. err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, "domain1.test", calls[0].domain) + assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID) assert.False(t, calls[0].connected) } diff --git a/proxy/internal/tcp/bench_test.go b/proxy/internal/tcp/bench_test.go new file mode 100644 index 000000000..049f8395d --- /dev/null +++ b/proxy/internal/tcp/bench_test.go @@ -0,0 +1,133 @@ +package tcp + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "testing" +) + +// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real +// TLS ClientHello and extracting the SNI. This is the per-connection cost +// added to every TLS connection on the main listener. +func BenchmarkPeekClientHello_TLS(b *testing.B) { + // Pre-generate a ClientHello by capturing what crypto/tls sends. + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + var hello []byte + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + hello = make([]byte, n) + copy(hello, buf[:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(hello) + conn := &readerConn{Reader: r} + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + // Simulate draining the peeked bytes (what the HTTP server would do). + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS +// connections that hit the fast non-handshake exit path. +func BenchmarkPeekClientHello_NonTLS(b *testing.B) { + httpReq := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(httpReq) + conn := &readerConn{Reader: r} + _, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn +// wrapper compared to a plain connection read. The peeked bytes use +// io.MultiReader which adds one indirection per Read call. +func BenchmarkPeekedConn_Read(b *testing.B) { + data := make([]byte, 4096) + peeked := make([]byte, 512) + buf := make([]byte, 1024) + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(data) + conn := &readerConn{Reader: r} + pc := newPeekedConn(conn, peeked) + for { + _, err := pc.Read(buf) + if err != nil { + break + } + } + } +} + +// BenchmarkExtractSNI measures just the in-memory SNI parsing cost, +// excluding I/O. +func BenchmarkExtractSNI(b *testing.B) { + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + payload := make([]byte, n-tlsRecordHeaderLen) + copy(payload, buf[tlsRecordHeaderLen:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + sni := extractSNI(payload) + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + } +} + +// readerConn wraps an io.Reader as a net.Conn for benchmarking. +// Only Read is functional; all other methods are no-ops. +type readerConn struct { + io.Reader + net.Conn +} + +func (c *readerConn) Read(b []byte) (int, error) { + return c.Reader.Read(b) +} diff --git a/proxy/internal/tcp/chanlistener.go b/proxy/internal/tcp/chanlistener.go new file mode 100644 index 000000000..ee64bc0a2 --- /dev/null +++ b/proxy/internal/tcp/chanlistener.go @@ -0,0 +1,76 @@ +package tcp + +import ( + "net" + "sync" +) + +// chanListener implements net.Listener by reading connections from a channel. +// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS. +type chanListener struct { + ch chan net.Conn + addr net.Addr + once sync.Once + closed chan struct{} +} + +func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener { + return &chanListener{ + ch: ch, + addr: addr, + closed: make(chan struct{}), + } +} + +// Accept waits for and returns the next connection from the channel. +func (l *chanListener) Accept() (net.Conn, error) { + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + return conn, nil + case <-l.closed: + // Drain buffered connections before returning. + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + _ = conn.Close() + default: + return nil, net.ErrClosed + } + } + } + } +} + +// Close signals the listener to stop accepting connections and drains +// any buffered connections that have not yet been accepted. +func (l *chanListener) Close() error { + l.once.Do(func() { + close(l.closed) + for { + select { + case conn, ok := <-l.ch: + if !ok { + return + } + _ = conn.Close() + default: + return + } + } + }) + return nil +} + +// Addr returns the listener's network address. +func (l *chanListener) Addr() net.Addr { + return l.addr +} + +var _ net.Listener = (*chanListener)(nil) diff --git a/proxy/internal/tcp/peekedconn.go b/proxy/internal/tcp/peekedconn.go new file mode 100644 index 000000000..26f3e5c7c --- /dev/null +++ b/proxy/internal/tcp/peekedconn.go @@ -0,0 +1,39 @@ +package tcp + +import ( + "bytes" + "io" + "net" +) + +// peekedConn wraps a net.Conn and prepends previously peeked bytes +// so that readers see the full original stream transparently. +type peekedConn struct { + net.Conn + reader io.Reader +} + +func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn { + return &peekedConn{ + Conn: conn, + reader: io.MultiReader(bytes.NewReader(peeked), conn), + } +} + +// Read replays the peeked bytes first, then reads from the underlying conn. +func (c *peekedConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// CloseWrite delegates to the underlying connection if it supports +// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn +// as an interface hides the concrete type's CloseWrite method, making +// half-close a silent no-op for all SNI-routed connections. +func (c *peekedConn) CloseWrite() error { + if hc, ok := c.Conn.(halfCloser); ok { + return hc.CloseWrite() + } + return nil +} + +var _ halfCloser = (*peekedConn)(nil) diff --git a/proxy/internal/tcp/proxyprotocol.go b/proxy/internal/tcp/proxyprotocol.go new file mode 100644 index 000000000..699b75a5d --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol.go @@ -0,0 +1,29 @@ +package tcp + +import ( + "fmt" + "net" + + "github.com/pires/go-proxyproto" +) + +// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection, +// conveying the real client address. +func writeProxyProtoV2(client, backend net.Conn) error { + tp := proxyproto.TCPv4 + if addr, ok := client.RemoteAddr().(*net.TCPAddr); ok && addr.IP.To4() == nil { + tp = proxyproto.TCPv6 + } + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: tp, + SourceAddr: client.RemoteAddr(), + DestinationAddr: client.LocalAddr(), + } + if _, err := header.WriteTo(backend); err != nil { + return fmt.Errorf("write PROXY protocol v2 header: %w", err) + } + return nil +} diff --git a/proxy/internal/tcp/proxyprotocol_test.go b/proxy/internal/tcp/proxyprotocol_test.go new file mode 100644 index 000000000..f8c48b2ab --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol_test.go @@ -0,0 +1,128 @@ +package tcp + +import ( + "bufio" + "net" + "testing" + + "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteProxyProtoV2_IPv4(t *testing.T) { + // Set up a real TCP listener and dial to get connections with real addresses. + ln, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp4", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + // Use a pipe as the backend: write the header to one end, read from the other. + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + // serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination. + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + // Read the PROXY protocol header from the backend read side. + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4") + + // serverConn.RemoteAddr() is the client's address (source in the header). + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + // serverConn.LocalAddr() is the server's address (destination in the header). + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} + +func TestWriteProxyProtoV2_IPv6(t *testing.T) { + // Set up a real TCP6 listener on loopback. + ln, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 not available:", err) + } + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp6", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6") + + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} diff --git a/proxy/internal/tcp/relay.go b/proxy/internal/tcp/relay.go new file mode 100644 index 000000000..39949818d --- /dev/null +++ b/proxy/internal/tcp/relay.go @@ -0,0 +1,156 @@ +package tcp + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +// errIdleTimeout is returned when a relay connection is closed due to inactivity. +var errIdleTimeout = errors.New("idle timeout") + +// DefaultIdleTimeout is the default idle timeout for TCP relay connections. +// A zero value disables idle timeout checking. +const DefaultIdleTimeout = 5 * time.Minute + +// halfCloser is implemented by connections that support half-close +// (e.g. *net.TCPConn). When one copy direction finishes, we signal +// EOF to the remote by closing the write side while keeping the read +// side open so the other direction can drain. +type halfCloser interface { + CloseWrite() error +} + +// copyBufPool avoids allocating a new 32KB buffer per io.Copy call. +var copyBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +// Relay copies data bidirectionally between src and dst until both +// sides are done or the context is canceled. When idleTimeout is +// non-zero, each direction's read is deadline-guarded; if no data +// flows within the timeout the connection is torn down. When one +// direction finishes, it half-closes the write side of the +// destination (if supported) to signal EOF, allowing the other +// direction to drain gracefully before the full connection teardown. +func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + _ = src.Close() + _ = dst.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var errSrcToDst, errDstToSrc error + + go func() { + defer wg.Done() + srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout) + halfClose(dst) + cancel() + }() + + go func() { + defer wg.Done() + dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout) + halfClose(src) + cancel() + }() + + wg.Wait() + + if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) { + logger.Debug("relay closed due to idle timeout") + } + if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) { + logger.Debugf("relay copy error (src→dst): %v", errSrcToDst) + } + if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) { + logger.Debugf("relay copy error (dst→src): %v", errDstToSrc) + } + + return srcToDst, dstToSrc +} + +// copyWithIdleTimeout copies from src to dst using a pooled buffer. +// When idleTimeout > 0 it sets a read deadline on src before each +// read and treats a timeout as an idle-triggered close. +func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + + if idleTimeout <= 0 { + return io.CopyBuffer(dst, src, *bufp) + } + + conn, ok := src.(net.Conn) + if !ok { + return io.CopyBuffer(dst, src, *bufp) + } + + buf := *bufp + var total int64 + for { + if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { + return total, err + } + nr, readErr := src.Read(buf) + if nr > 0 { + n, err := checkedWrite(dst, buf[:nr]) + total += n + if err != nil { + return total, err + } + } + if readErr != nil { + if netutil.IsTimeout(readErr) { + return total, errIdleTimeout + } + return total, readErr + } + } +} + +// checkedWrite writes buf to dst and returns the number of bytes written. +// It guards against short writes and negative counts per io.Copy convention. +func checkedWrite(dst io.Writer, buf []byte) (int64, error) { + nw, err := dst.Write(buf) + if nw < 0 || nw > len(buf) { + nw = 0 + } + if err != nil { + return int64(nw), err + } + if nw != len(buf) { + return int64(nw), io.ErrShortWrite + } + return int64(nw), nil +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err) +} + +// halfClose attempts to half-close the write side of the connection. +// If the connection does not support half-close, this is a no-op. +func halfClose(conn net.Conn) { + if hc, ok := conn.(halfCloser); ok { + // Best-effort; the full close will follow shortly. + _ = hc.CloseWrite() + } +} diff --git a/proxy/internal/tcp/relay_test.go b/proxy/internal/tcp/relay_test.go new file mode 100644 index 000000000..e42d65b9d --- /dev/null +++ b/proxy/internal/tcp/relay_test.go @@ -0,0 +1,210 @@ +package tcp + +import ( + "context" + "fmt" + "io" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +func TestRelay_BidirectionalCopy(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + srcData := []byte("hello from src") + dstData := []byte("hello from dst") + + // dst side: write response first, then read + close. + go func() { + _, _ = dstClient.Write(dstData) + buf := make([]byte, 256) + _, _ = dstClient.Read(buf) + dstClient.Close() + }() + + // src side: read the response, then send data + close. + go func() { + buf := make([]byte, 256) + _, _ = srcClient.Read(buf) + _, _ = srcClient.Write(srcData) + srcClient.Close() + }() + + s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0) + + assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst") + assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src") +} + +func TestRelay_ContextCancellation(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer srcClient.Close() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + // Cancel should cause Relay to return. + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after context cancellation") + } +} + +func TestRelay_OneSideClosed(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Close src immediately. Relay should complete without hanging. + srcClient.Close() + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after one side closed") + } +} + +func TestRelay_LargeTransfer(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // 1MB of data. + data := make([]byte, 1<<20) + for i := range data { + data[i] = byte(i % 256) + } + + go func() { + _, _ = srcClient.Write(data) + srcClient.Close() + }() + + errCh := make(chan error, 1) + go func() { + received, err := io.ReadAll(dstClient) + if err != nil { + errCh <- err + return + } + if len(received) != len(data) { + errCh <- fmt.Errorf("expected %d bytes, got %d", len(data), len(received)) + return + } + errCh <- nil + dstClient.Close() + }() + + s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0) + assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes") + require.NoError(t, <-errCh) +} + +func TestRelay_IdleTimeout(t *testing.T) { + // Use real TCP connections so SetReadDeadline works (net.Pipe + // does not support deadlines). + srcLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer srcLn.Close() + + dstLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer dstLn.Close() + + srcClient, err := net.Dial("tcp", srcLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer srcClient.Close() + + srcServer, err := srcLn.Accept() + if err != nil { + t.Fatal(err) + } + + dstClient, err := net.Dial("tcp", dstLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer dstClient.Close() + + dstServer, err := dstLn.Accept() + if err != nil { + t.Fatal(err) + } + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Send initial data to prove the relay works. + go func() { + _, _ = srcClient.Write([]byte("ping")) + }() + + done := make(chan struct{}) + var s2d, d2s int64 + go func() { + s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) + close(done) + }() + + // Read the forwarded data on the dst side. + buf := make([]byte, 64) + n, err := dstClient.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Now stop sending. The relay should close after the idle timeout. + select { + case <-done: + assert.Greater(t, s2d, int64(0), "should have transferred initial data") + _ = d2s + case <-time.After(5 * time.Second): + t.Fatal("Relay did not exit after idle timeout") + } +} + +func TestIsExpectedError(t *testing.T) { + assert.True(t, netutil.IsExpectedError(net.ErrClosed)) + assert.True(t, netutil.IsExpectedError(context.Canceled)) + assert.True(t, netutil.IsExpectedError(io.EOF)) + assert.False(t, netutil.IsExpectedError(io.ErrUnexpectedEOF)) +} diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go new file mode 100644 index 000000000..9f8660aeb --- /dev/null +++ b/proxy/internal/tcp/router.go @@ -0,0 +1,671 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/restrict" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// defaultDialTimeout is the fallback dial timeout when no per-route +// timeout is configured. +const defaultDialTimeout = 30 * time.Second + +// errAccessRestricted is returned by relayTCP for access restriction +// denials so callers can skip warn-level logging (already logged at debug). +var errAccessRestricted = errors.New("rejected by access restrictions") + +// SNIHost is a typed key for SNI hostname lookups. +type SNIHost string + +// RouteType specifies how a connection should be handled. +type RouteType int + +const ( + // RouteHTTP routes the connection through the HTTP reverse proxy. + RouteHTTP RouteType = iota + // RouteTCP relays the connection directly to the backend (TLS passthrough). + RouteTCP +) + +const ( + // sniPeekTimeout is the deadline for reading the TLS ClientHello. + sniPeekTimeout = 5 * time.Second + // DefaultDrainTimeout is the default grace period for in-flight relay + // connections to finish during shutdown. + DefaultDrainTimeout = 30 * time.Second + // DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router. + DefaultMaxRelayConns = 4096 + // httpChannelBuffer is the capacity of the channel feeding HTTP connections. + httpChannelBuffer = 4096 +) + +// DialResolver returns a DialContextFunc for the given account. +type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error) + +// Route describes where a connection for a given SNI should be sent. +type Route struct { + Type RouteType + AccountID types.AccountID + ServiceID types.ServiceID + // Domain is the service's configured domain, used for access log entries. + Domain string + // Protocol is the frontend protocol (tcp, tls), used for access log entries. + Protocol accesslog.Protocol + // Target is the backend address for TCP relay (e.g. "10.0.0.5:5432"). + Target string + // ProxyProtocol enables sending a PROXY protocol v2 header to the backend. + ProxyProtocol bool + // DialTimeout overrides the default dial timeout for this route. + // Zero uses defaultDialTimeout. + DialTimeout time.Duration + // SessionIdleTimeout overrides the default idle timeout for relay connections. + // Zero uses DefaultIdleTimeout. + SessionIdleTimeout time.Duration + // Filter holds connection-level IP/geo restrictions. Nil means no restrictions. + Filter *restrict.Filter +} + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// RelayObserver receives callbacks for TCP relay lifecycle events. +// All methods must be safe for concurrent use. +type RelayObserver interface { + TCPRelayStarted(accountID types.AccountID) + TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) + TCPRelayDialError(accountID types.AccountID) + TCPRelayRejected(accountID types.AccountID) +} + +// Router accepts raw TCP connections on a shared listener, peeks at +// the TLS ClientHello to extract the SNI, and routes the connection +// to either the HTTP reverse proxy or a direct TCP relay. +type Router struct { + logger *log.Logger + // httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter. + httpCh chan net.Conn + httpListener *chanListener + mu sync.RWMutex + routes map[SNIHost][]Route + fallback *Route + draining bool + dialResolve DialResolver + activeConns sync.WaitGroup + activeRelays sync.WaitGroup + relaySem chan struct{} + drainDone chan struct{} + observer RelayObserver + accessLog l4Logger + geo restrict.GeoResolver + // svcCtxs tracks a context per service ID. All relay goroutines for a + // service derive from its context; canceling it kills them immediately. + svcCtxs map[types.ServiceID]context.Context + svcCancels map[types.ServiceID]context.CancelFunc +} + +// NewRouter creates a new SNI-based connection router. +func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router { + httpCh := make(chan net.Conn, httpChannelBuffer) + return &Router{ + logger: logger, + httpCh: httpCh, + httpListener: newChanListener(httpCh, addr), + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// NewPortRouter creates a Router for a dedicated port without an HTTP +// channel. Connections that don't match any SNI route fall through to +// the fallback relay (if set) or are closed. +func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router { + return &Router{ + logger: logger, + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// HTTPListener returns a net.Listener that yields connections routed +// to the HTTP handler. Use this with http.Server.ServeTLS. +func (r *Router) HTTPListener() net.Listener { + return r.httpListener +} + +// AddRoute registers an SNI route. Multiple routes for the same host are +// stored and resolved by priority at lookup time (HTTP > TCP). +// Empty host is ignored to prevent conflicts with ECH/ESNI fallback. +func (r *Router) AddRoute(host SNIHost, route Route) { + host = SNIHost(strings.ToLower(string(host))) + if host == "" { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + routes := r.routes[host] + for i, existing := range routes { + if existing.ServiceID == route.ServiceID { + r.cancelServiceLocked(route.ServiceID) + routes[i] = route + return + } + } + r.routes[host] = append(routes, route) +} + +// RemoveRoute removes the route for the given host and service ID. +// Active relay connections for the service are closed immediately. +// If other routes remain for the host, they are preserved. +func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) { + host = SNIHost(strings.ToLower(string(host))) + + r.mu.Lock() + defer r.mu.Unlock() + + r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool { + return route.ServiceID == svcID + }) + if len(r.routes[host]) == 0 { + delete(r.routes, host) + } + r.cancelServiceLocked(svcID) +} + +// SetFallback registers a catch-all route for connections that don't +// match any SNI route. On a port router this handles plain TCP relay; +// on the main router it takes priority over the HTTP channel. +func (r *Router) SetFallback(route Route) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = &route +} + +// RemoveFallback clears the catch-all fallback route and closes any +// active relay connections for the given service. +func (r *Router) RemoveFallback(svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = nil + r.cancelServiceLocked(svcID) +} + +// SetObserver sets the relay lifecycle observer. Must be called before Serve. +func (r *Router) SetObserver(obs RelayObserver) { + r.mu.Lock() + defer r.mu.Unlock() + r.observer = obs +} + +// SetAccessLogger sets the L4 access logger. Must be called before Serve. +func (r *Router) SetAccessLogger(l l4Logger) { + r.mu.Lock() + defer r.mu.Unlock() + r.accessLog = l +} + +// getObserver returns the current relay observer under the read lock. +func (r *Router) getObserver() RelayObserver { + r.mu.RLock() + defer r.mu.RUnlock() + return r.observer +} + +// IsEmpty returns true when the router has no SNI routes and no fallback. +func (r *Router) IsEmpty() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.fallback == nil +} + +// Serve accepts connections from ln and routes them based on SNI. +// It blocks until ctx is canceled or ln is closed, then drains +// active relay connections up to DefaultDrainTimeout. +func (r *Router) Serve(ctx context.Context, ln net.Listener) error { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + _ = ln.Close() + if r.httpListener != nil { + r.httpListener.Close() + } + case <-done: + } + }() + + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + if ok := r.Drain(DefaultDrainTimeout); !ok { + r.logger.Warn("timed out waiting for connections to drain") + } + return nil + } + r.logger.Debugf("SNI router accept: %v", err) + continue + } + r.activeConns.Add(1) + go func() { + defer r.activeConns.Done() + r.handleConn(ctx, conn) + }() + } +} + +// handleConn peeks at the TLS ClientHello and routes the connection. +func (r *Router) handleConn(ctx context.Context, conn net.Conn) { + // Fast path: when no SNI routes and no HTTP channel exist (pure TCP + // fallback port), skip the TLS peek entirely to avoid read errors on + // non-TLS connections and reduce latency. + if r.isFallbackOnly() { + r.handleUnmatched(ctx, conn) + return + } + + if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil { + r.logger.Debugf("set SNI peek deadline: %v", err) + _ = conn.Close() + return + } + + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + r.logger.Debugf("SNI peek: %v", err) + if wrapped != nil { + r.handleUnmatched(ctx, wrapped) + } else { + _ = conn.Close() + } + return + } + + if err := wrapped.SetReadDeadline(time.Time{}); err != nil { + r.logger.Debugf("clear SNI peek deadline: %v", err) + _ = wrapped.Close() + return + } + + host := SNIHost(strings.ToLower(sni)) + route, ok := r.lookupRoute(host) + if !ok { + r.handleUnmatched(ctx, wrapped) + return + } + + if route.Type == RouteHTTP { + r.sendToHTTP(wrapped) + return + } + + if err := r.relayTCP(ctx, wrapped, host, route); err != nil { + if !errors.Is(err, errAccessRestricted) { + r.logger.WithFields(log.Fields{ + "sni": host, + "service_id": route.ServiceID, + "target": route.Target, + }).Warnf("TCP relay: %v", err) + } + _ = wrapped.Close() + } +} + +// isFallbackOnly returns true when the router has no SNI routes and no HTTP +// channel, meaning all connections should go directly to the fallback relay. +func (r *Router) isFallbackOnly() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.httpCh == nil +} + +// handleUnmatched routes a connection that didn't match any SNI route. +// This includes ECH/ESNI connections where the cleartext SNI is empty. +// It tries the fallback relay first, then the HTTP channel, and closes +// the connection if neither is available. +func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { + r.mu.RLock() + fb := r.fallback + r.mu.RUnlock() + + if fb != nil { + if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { + if !errors.Is(err, errAccessRestricted) { + r.logger.WithFields(log.Fields{ + "service_id": fb.ServiceID, + "target": fb.Target, + }).Warnf("TCP relay (fallback): %v", err) + } + _ = conn.Close() + } + return + } + r.sendToHTTP(conn) +} + +// lookupRoute returns the highest-priority route for the given SNI host. +// HTTP routes take precedence over TCP routes. +func (r *Router) lookupRoute(host SNIHost) (Route, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + routes, ok := r.routes[host] + if !ok || len(routes) == 0 { + return Route{}, false + } + best := routes[0] + for _, route := range routes[1:] { + if route.Type < best.Type { + best = route + } + } + return best, true +} + +// sendToHTTP feeds the connection to the HTTP handler via the channel. +// If no HTTP channel is configured (port router), the router is +// draining, or the channel is full, the connection is closed. +func (r *Router) sendToHTTP(conn net.Conn) { + if r.httpCh == nil { + _ = conn.Close() + return + } + + r.mu.RLock() + draining := r.draining + r.mu.RUnlock() + + if draining { + _ = conn.Close() + return + } + + select { + case r.httpCh <- conn: + default: + r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr()) + _ = conn.Close() + } +} + +// Drain prevents new relay connections from starting and waits for all +// in-flight connection handlers and active relays to finish, up to the +// given timeout. Returns true if all completed, false on timeout. +func (r *Router) Drain(timeout time.Duration) bool { + r.mu.Lock() + r.draining = true + if r.drainDone == nil { + done := make(chan struct{}) + go func() { + r.activeConns.Wait() + r.activeRelays.Wait() + close(done) + }() + r.drainDone = done + } + done := r.drainDone + r.mu.Unlock() + + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +// cancelServiceLocked cancels and removes the context for the given service, +// closing all its active relay connections. Must be called with mu held. +func (r *Router) cancelServiceLocked(svcID types.ServiceID) { + if cancel, ok := r.svcCancels[svcID]; ok { + cancel() + delete(r.svcCtxs, svcID) + delete(r.svcCancels, svcID) + } +} + +// SetGeo sets the geolocation lookup used for country-based restrictions. +func (r *Router) SetGeo(geo restrict.GeoResolver) { + r.mu.Lock() + defer r.mu.Unlock() + r.geo = geo +} + +// checkRestrictions evaluates the route's access filter against the +// connection's remote address. Returns Allow if the connection is +// permitted, or a deny verdict indicating the reason. +func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict { + if route.Filter == nil { + return restrict.Allow + } + + addr, err := addrFromConn(conn) + if err != nil { + r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr()) + return restrict.DenyCIDR + } + + r.mu.RLock() + geo := r.geo + r.mu.RUnlock() + + return route.Filter.Check(addr, geo) +} + +// relayTCP sets up and runs a bidirectional TCP relay. +// The caller owns conn and must close it if this method returns an error. +// On success (nil error), both conn and backend are closed by the relay. +func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { + if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow { + if route.Filter != nil && route.Filter.IsObserveOnly(verdict) { + r.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", conn.RemoteAddr(), sni, verdict) + r.logL4Deny(route, conn, verdict, true) + } else { + r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict) + r.logL4Deny(route, conn, verdict, false) + return errAccessRestricted + } + } + + svcCtx, err := r.acquireRelay(ctx, route) + if err != nil { + return err + } + defer func() { + <-r.relaySem + r.activeRelays.Done() + }() + + backend, err := r.dialBackend(svcCtx, route) + if err != nil { + obs := r.getObserver() + if obs != nil { + obs.TCPRelayDialError(route.AccountID) + } + return err + } + + if route.ProxyProtocol { + if err := writeProxyProtoV2(conn, backend); err != nil { + _ = backend.Close() + return fmt.Errorf("write PROXY protocol header: %w", err) + } + } + + obs := r.getObserver() + if obs != nil { + obs.TCPRelayStarted(route.AccountID) + } + + entry := r.logger.WithFields(log.Fields{ + "sni": sni, + "service_id": route.ServiceID, + "target": route.Target, + }) + entry.Debug("TCP relay started") + + idleTimeout := route.SessionIdleTimeout + if idleTimeout <= 0 { + idleTimeout = DefaultIdleTimeout + } + + start := time.Now() + s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout) + elapsed := time.Since(start) + + if obs != nil { + obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s) + } + entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s) + + r.logL4Entry(route, conn, elapsed, s2d, d2s) + return nil +} + +// acquireRelay checks draining state, increments activeRelays, and acquires +// a semaphore slot. Returns the per-service context on success. +// The caller must release the semaphore and call activeRelays.Done() when done. +func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) { + r.mu.Lock() + if r.draining { + r.mu.Unlock() + return nil, errors.New("router is draining") + } + r.activeRelays.Add(1) + svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID) + r.mu.Unlock() + + select { + case r.relaySem <- struct{}{}: + return svcCtx, nil + default: + r.activeRelays.Done() + obs := r.getObserver() + if obs != nil { + obs.TCPRelayRejected(route.AccountID) + } + return nil, errors.New("TCP relay connection limit reached") + } +} + +// dialBackend resolves the dialer for the route's account and dials the backend. +func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) { + dialFn, err := r.dialResolve(route.AccountID) + if err != nil { + return nil, fmt.Errorf("resolve dialer: %w", err) + } + + dialTimeout := route.DialTimeout + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout) + backend, err := dialFn(dialCtx, "tcp", route.Target) + dialCancel() + if err != nil { + return nil, fmt.Errorf("dial backend %s: %w", route.Target, err) + } + return backend, nil +} + +// logL4Entry sends a TCP relay access log entry if an access logger is configured. +func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) { + r.mu.RLock() + al := r.accessLog + r.mu.RUnlock() + + if al == nil { + return + } + + sourceIP, _ := addrFromConn(conn) + + al.LogL4(accesslog.L4Entry{ + AccountID: route.AccountID, + ServiceID: route.ServiceID, + Protocol: route.Protocol, + Host: route.Domain, + SourceIP: sourceIP, + DurationMs: duration.Milliseconds(), + BytesUpload: bytesUp, + BytesDownload: bytesDown, + }) +} + +// logL4Deny sends an access log entry for a denied connection. +func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict, observeOnly bool) { + r.mu.RLock() + al := r.accessLog + r.mu.RUnlock() + + if al == nil { + return + } + + sourceIP, _ := addrFromConn(conn) + + entry := accesslog.L4Entry{ + AccountID: route.AccountID, + ServiceID: route.ServiceID, + Protocol: route.Protocol, + Host: route.Domain, + SourceIP: sourceIP, + DenyReason: verdict.String(), + } + if verdict.IsCrowdSec() { + entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()} + if observeOnly { + entry.Metadata["crowdsec_mode"] = "observe" + entry.DenyReason = "" + } + } + al.LogL4(entry) +} + +// getOrCreateServiceCtxLocked returns the context for a service, creating one +// if it doesn't exist yet. The context is a child of the server context. +// Must be called with mu held. +func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context { + if ctx, ok := r.svcCtxs[svcID]; ok { + return ctx + } + ctx, cancel := context.WithCancel(parent) + r.svcCtxs[svcID] = ctx + r.svcCancels[svcID] = cancel + return ctx +} + +// addrFromConn extracts a netip.Addr from a connection's remote address. +func addrFromConn(conn net.Conn) (netip.Addr, error) { + remote := conn.RemoteAddr() + if remote == nil { + return netip.Addr{}, errors.New("no remote address") + } + ap, err := netip.ParseAddrPort(remote.String()) + if err != nil { + return netip.Addr{}, err + } + return ap.Addr().Unmap(), nil +} diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go new file mode 100644 index 000000000..93b6560f4 --- /dev/null +++ b/proxy/internal/tcp/router_test.go @@ -0,0 +1,1741 @@ +package tcp + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/restrict" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRouter_HTTPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + router.AddRoute("example.com", Route{Type: RouteHTTP}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Dial in a goroutine. The TLS handshake will block since nothing + // completes it on the HTTP side, but we only care about routing. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + // Send a TLS ClientHello manually. + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + // Verify the connection was routed to the HTTP channel. + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("no connection received on HTTP channel") + } +} + +func TestRouter_TCPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Set up a TLS backend that the relay will connect to. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + backendAddr := backendLn.Addr().String() + + // Accept one connection on the backend, echo data back. + backendReady := make(chan struct{}) + go func() { + close(backendReady) + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + <-backendReady + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendAddr, + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Connect as a TLS client; the proxy should passthrough to the backend. + clientConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer clientConn.Close() + + testData := []byte("hello through TCP passthrough") + _, err = clientConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := clientConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through TCP passthrough") +} + +func TestRouter_UnknownSNIGoesToHTTP(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + // No routes registered. + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("unknown SNI should be routed to HTTP") + } +} + +// TestRouter_NonTLSConnectionDropped verifies that a non-TLS connection +// on the shared port is closed by the router (SNI peek fails to find a +// valid ClientHello, so there is no route match). +func TestRouter_NonTLSConnectionDropped(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Register a TLS passthrough route. Non-TLS should NOT match. + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: "127.0.0.1:9999", + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Send plain HTTP (non-TLS) data. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: tcp.example.com\r\n\r\n")) + + // Non-TLS traffic on a port with RouteTCP goes to the HTTP channel + // because there's no valid SNI to match. Verify it reaches HTTP. + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "non-TLS connection should fall through to HTTP") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("non-TLS connection was not routed to HTTP") + } +} + +// TestRouter_TLSAndHTTPCoexist verifies that a shared port with both HTTP +// and TLS passthrough routes correctly demuxes based on the SNI hostname. +func TestRouter_TLSAndHTTPCoexist(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes data. + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + // HTTP route. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // TLS passthrough route. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // 1. TLS connection with SNI "tcp.example.com" → TLS passthrough. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + testData := []byte("passthrough data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "TLS passthrough should relay data") + tlsConn.Close() + + // 2. TLS connection with SNI "app.example.com" → HTTP handler. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "HTTP SNI should go to HTTP handler") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } +} + +func TestRouter_AddRemoveRoute(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.1:5432"}) + + route, ok := router.lookupRoute("a.example.com") + assert.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type) + + route, ok = router.lookupRoute("b.example.com") + assert.True(t, ok) + assert.Equal(t, RouteTCP, route.Type) + + router.RemoveRoute("a.example.com", "svc-a") + _, ok = router.lookupRoute("a.example.com") + assert.False(t, ok) +} + +func TestChanListener_AcceptAndClose(t *testing.T) { + ch := make(chan net.Conn, 1) + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + ln := newChanListener(ch, addr) + + assert.Equal(t, addr, ln.Addr()) + + // Send a connection. + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + ch <- serverConn + + conn, err := ln.Accept() + require.NoError(t, err) + assert.Equal(t, serverConn, conn) + + // Close should cause Accept to return error. + require.NoError(t, ln.Close()) + // Double close should be safe. + require.NoError(t, ln.Close()) + + _, err = ln.Accept() + assert.ErrorIs(t, err, net.ErrClosed) +} + +func TestRouter_HTTPPrecedenceGuard(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + host := SNIHost("app.example.com") + + t.Run("http takes precedence over tcp at lookup", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP route must take precedence over TCP") + assert.Equal(t, types.ServiceID("svc-http"), route.ServiceID) + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("tcp becomes active when http is removed", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + router.RemoveRoute(host, "svc-http") + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "TCP should take over after HTTP removal") + assert.Equal(t, types.ServiceID("svc-tcp"), route.ServiceID) + + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("order of add does not matter", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP takes precedence regardless of add order") + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("same service id updates in place", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.2:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, "10.0.0.2:443", route.Target, "route should be updated in place") + + router.RemoveRoute(host, "svc-1") + _, ok = router.lookupRoute(host) + assert.False(t, ok) + }) + + t.Run("double remove is safe", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-1"}) + router.RemoveRoute(host, "svc-1") + router.RemoveRoute(host, "svc-1") + + _, ok := router.lookupRoute(host) + assert.False(t, ok, "route should be gone after removal") + }) + + t.Run("remove does not affect other hosts", func(t *testing.T) { + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.2:22"}) + + router.RemoveRoute("a.example.com", "svc-a") + + _, ok := router.lookupRoute(SNIHost("a.example.com")) + assert.False(t, ok) + + route, ok := router.lookupRoute(SNIHost("b.example.com")) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "removing one host must not affect another") + + router.RemoveRoute("b.example.com", "svc-b") + }) +} + +func TestRouter_SetRemoveFallback(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new port router should be empty") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb", Target: "10.0.0.1:5432"}) + assert.False(t, router.IsEmpty(), "router with fallback should not be empty") + + router.AddRoute("a.example.com", Route{Type: RouteTCP, ServiceID: "svc-a", Target: "10.0.0.2:443"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb") + assert.False(t, router.IsEmpty(), "router with SNI route should not be empty") + + router.RemoveRoute("a.example.com", "svc-a") + assert.True(t, router.IsEmpty(), "router with no routes and no fallback should be empty") +} + +func TestPortRouter_FallbackRelaysData(t *testing.T) { + // Backend echo server. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Plain TCP (non-TLS) connection should be relayed via fallback. + // Use exactly 5 bytes. PeekClientHello reads 5 bytes as the TLS + // header, so a single 5-byte write lands as one chunk at the backend. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("hello") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through fallback relay") +} + +func TestPortRouter_FallbackOnUnknownSNI(t *testing.T) { + // Backend TLS echo server. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + // Only a fallback, no SNI route for "unknown.example.com". + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with unknown SNI → fallback relay to TLS backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("hello through fallback TLS") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "unknown SNI should relay through fallback") +} + +func TestPortRouter_SNIWinsOverFallback(t *testing.T) { + // Two backend echo servers: one for SNI match, one for fallback. + sniBacked := startEchoTLS(t) + fbBacked := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "sni-service", + Target: sniBacked.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBacked.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with matching SNI should go to SNI backend, not fallback. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("SNI route data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "SNI match should use SNI route, not fallback") +} + +func TestPortRouter_NoFallbackNoHTTP_Closes(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("hello")) + + // Connection should be closed by the router (no fallback, no HTTP). + buf := make([]byte, 1) + _ = conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be closed when no fallback and no HTTP channel") +} + +func TestRouter_FallbackAndHTTPCoexist(t *testing.T) { + // Fallback backend echo server (plain TCP). + fbBackend, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer fbBackend.Close() + + go func() { + conn, err := fbBackend.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + + // HTTP route for known SNI. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // Fallback for non-TLS / unknown SNI. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with known HTTP SNI → should go to HTTP channel. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "known HTTP SNI should go to HTTP channel") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } + + // 2. Plain TCP (non-TLS) → should go to fallback, not HTTP. + // Use exactly 5 bytes to match PeekClientHello header size. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("plain") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "non-TLS should be relayed via fallback, not HTTP") +} + +// startEchoTLS starts a TLS echo server and returns the listener. +func startEchoTLS(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + if _, err := conn.Write(buf[:n]); err != nil { + return + } + } + }() + + return ln +} + +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"tcp.example.com"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: key, + } +} + +func TestRouter_DrainWaitsForRelays(t *testing.T) { + logger := log.StandardLogger() + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Accept connections: echo first message, then hold open until told to close. + closeBackend := make(chan struct{}) + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + <-closeBackend + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + serveDone := make(chan struct{}) + go func() { + _ = router.Serve(ctx, ln) + close(serveDone) + }() + + // Open a relay connection (non-TLS, hits fallback). + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + _, _ = conn.Write([]byte("hello")) + + // Wait for the echo to confirm the relay is fully established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Drain with a short timeout should fail because the relay is still active. + assert.False(t, router.Drain(50*time.Millisecond), "drain should timeout with active relay") + + // Close backend connections so relays finish. + close(closeBackend) + _ = conn.Close() + + // Drain should now complete quickly. + assert.True(t, router.Drain(2*time.Second), "drain should succeed after relays end") + + cancel() + <-serveDone +} + +func TestRouter_DrainEmptyReturnsImmediately(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + start := time.Now() + ok := router.Drain(5 * time.Second) + elapsed := time.Since(start) + + assert.True(t, ok) + assert.Less(t, elapsed, 100*time.Millisecond, "drain with no relays should return immediately") +} + +// TestRemoveRoute_KillsActiveRelays verifies that removing a route +// immediately kills active relay connections for that service. +func TestRemoveRoute_KillsActiveRelays(t *testing.T) { + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes first message, then holds connection open. + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + // Hold the connection open. + for { + if _, err := c.Read(buf); err != nil { + return + } + } + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-1", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a relay connection. + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("hello")) + require.NoError(t, err) + + // Wait for echo to confirm relay is established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Remove the fallback: should kill the active relay. + router.RemoveFallback("svc-1") + + // The client connection should see an error (server closed). + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be killed after service removal") +} + +// TestRemoveRoute_KillsSNIRelays verifies that removing an SNI route +// kills its active relays without affecting other services. +func TestRemoveRoute_KillsSNIRelays(t *testing.T) { + backend := startEchoTLS(t) + defer backend.Close() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tls.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-tls", + Target: backend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a TLS relay. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + defer tlsConn.Close() + + _, err = tlsConn.Write([]byte("ping")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Remove the route: active relay should die. + router.RemoveRoute("tls.example.com", "svc-tls") + + _ = tlsConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = tlsConn.Read(buf) + assert.Error(t, err, "TLS relay should be killed after route removal") +} + +// TestPortRouter_SNIAndTCPFallbackCoexist verifies that a single port can +// serve both SNI-routed TLS passthrough and plain TCP fallback simultaneously. +func TestPortRouter_SNIAndTCPFallbackCoexist(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + // SNI route for a specific domain. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-sni", + Target: sniBackend.Addr().String(), + }) + // TCP fallback for everything else. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with matching SNI → goes to SNI backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = tlsConn.Write([]byte("sni-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-data", string(buf[:n]), "SNI match → SNI backend") + tlsConn.Close() + + // 2. Plain TCP (no TLS) → goes to fallback. + tcpConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + + _, err = tcpConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = tcpConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n]), "plain TCP → fallback backend") + tcpConn.Close() + + // 3. TLS with unknown SNI → also goes to fallback. + unknownBackend := startEchoTLS(t) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: unknownBackend.Addr().String(), + }) + + unknownTLS, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = unknownTLS.Write([]byte("unknown-sni")) + require.NoError(t, err) + n, err = unknownTLS.Read(buf) + require.NoError(t, err) + assert.Equal(t, "unknown-sni", string(buf[:n]), "unknown SNI → fallback backend") + unknownTLS.Close() +} + +// TestPortRouter_UpdateRouteSwapsSNI verifies that updating a route +// (remove + add with different target) correctly routes to the new backend. +func TestPortRouter_UpdateRouteSwapsSNI(t *testing.T) { + backend1 := startEchoTLS(t) + backend2 := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Initial route → backend1. + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend1.Addr().String(), + }) + + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("v1")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v1", string(buf[:n])) + conn1.Close() + + // Update: remove old route, add new → backend2. + router.RemoveRoute("db.example.com", "svc-db") + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend2.Addr().String(), + }) + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("v2")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v2", string(buf[:n])) + conn2.Close() +} + +// TestPortRouter_RemoveSNIFallsThrough verifies that after removing an +// SNI route, connections for that domain fall through to the fallback. +func TestPortRouter_RemoveSNIFallsThrough(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: sniBackend.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Before removal: SNI matches → sniBackend. + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("before")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "before", string(buf[:n])) + conn1.Close() + + // Remove SNI route. Should fall through to fallback. + router.RemoveRoute("db.example.com", "svc-db") + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("after")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n]), "after removal, should reach fallback") + conn2.Close() +} + +// TestPortRouter_RemoveFallbackCloses verifies that after removing the +// fallback, non-matching connections are closed. +func TestPortRouter_RemoveFallbackCloses(t *testing.T) { + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // With fallback: plain TCP works. + conn1, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = conn1.Write([]byte("hello")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + conn1.Close() + + // Remove fallback. + router.RemoveFallback("svc-fb") + + // Without fallback on a port router (no HTTP channel): connection should be closed. + conn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn2.Close() + _, _ = conn2.Write([]byte("bye")) + _ = conn2.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn2.Read(buf) + assert.Error(t, err, "without fallback, connection should be closed") +} + +// TestPortRouter_HTTPToTLSTransition verifies that switching a service from +// HTTP-only to TLS-only via remove+add doesn't orphan the old HTTP route. +func TestPortRouter_HTTPToTLSTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: HTTP-only. SNI connections go to HTTP channel. + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener did not receive connection for HTTP-only route") + } + + // Phase 2: Simulate update to TLS-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS connection should succeed after HTTP→TLS transition") + defer tlsConn2.Close() + + _, err = tlsConn2.Write([]byte("hello-tls")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello-tls", string(buf[:n]), "data should relay to TLS backend") +} + +// TestPortRouter_TLSToHTTPTransition verifies that switching a service from +// TLS-only to HTTP-only via remove+add doesn't orphan the old TLS route. +func TestPortRouter_TLSToHTTPTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: TLS-only. Route relays to backend. + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS relay should work before transition") + _, err = tlsConn.Write([]byte("tls-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "tls-data", string(buf[:n])) + tlsConn.Close() + + // Phase 2: Simulate update to HTTP-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + // TLS connection should now go to the HTTP listener, NOT to the old TLS backend. + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn2.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener should receive connection after TLS→HTTP transition") + } +} + +// TestPortRouter_MultiDomainSamePort verifies that two TLS services sharing +// the same port router are independently routable and removable. +func TestPortRouter_MultiDomainSamePort(t *testing.T) { + logger := log.StandardLogger() + backend1 := startEchoTLSMulti(t) + backend2 := startEchoTLSMulti(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + router.AddRoute("svc1.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: backend1.Addr().String()}) + router.AddRoute("svc2.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-2", Target: backend2.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // Both domains route independently. + for _, tc := range []struct { + sni string + data string + }{ + {"svc1.example.com", "hello-svc1"}, + {"svc2.example.com", "hello-svc2"}, + } { + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: tc.sni, InsecureSkipVerify: true}, + ) + require.NoError(t, err, "dial %s", tc.sni) + _, err = conn.Write([]byte(tc.data)) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, tc.data, string(buf[:n])) + conn.Close() + } + + // Remove svc1. Router should NOT be empty (svc2 still present). + router.RemoveRoute("svc1.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "router should not be empty with one route remaining") + + // svc2 still works. + conn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "svc2.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = conn2.Write([]byte("still-alive")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "still-alive", string(buf[:n])) + conn2.Close() + + // Remove svc2. Router is now empty. + router.RemoveRoute("svc2.example.com", "svc-2") + assert.True(t, router.IsEmpty(), "router should be empty after removing all routes") +} + +// TestPortRouter_SNIAndFallbackLifecycle verifies the full lifecycle of SNI +// routes and TCP fallback coexisting on the same port router, including the +// ordering of add/remove operations. +func TestPortRouter_SNIAndFallbackLifecycle(t *testing.T) { + logger := log.StandardLogger() + sniBackend := startEchoTLS(t) + fallbackBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Step 1: Add fallback first (port mapping), then SNI route (TLS service). + router.SetFallback(Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "pm-1", Target: fallbackBackend.Addr().String()}) + router.AddRoute("tls.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: sniBackend.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // SNI traffic goes to TLS backend. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = tlsConn.Write([]byte("sni-traffic")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-traffic", string(buf[:n])) + tlsConn.Close() + + // Plain TCP goes to fallback. + plainConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = plainConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = plainConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n])) + plainConn.Close() + + // Step 2: Remove SNI route. Fallback still works, router not empty. + router.RemoveRoute("tls.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "fallback still present") + + plainConn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + // Must send >= 5 bytes so the SNI peek completes immediately + // without waiting for the 5-second peek timeout. + _, err = plainConn2.Write([]byte("after")) + require.NoError(t, err) + n, err = plainConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n])) + plainConn2.Close() + + // Step 3: Remove fallback. Router is now empty. + router.RemoveFallback("pm-1") + assert.True(t, router.IsEmpty()) +} + +// TestPortRouter_IsEmptyTransitions verifies IsEmpty reflects correct state +// through all add/remove operations. +func TestPortRouter_IsEmptyTransitions(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new router") + + router.AddRoute("a.com", Route{Type: RouteTCP, ServiceID: "svc-a"}) + assert.False(t, router.IsEmpty(), "after adding route") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb1"}) + assert.False(t, router.IsEmpty(), "route + fallback") + + router.RemoveRoute("a.com", "svc-a") + assert.False(t, router.IsEmpty(), "fallback only") + + router.RemoveFallback("svc-fb1") + assert.True(t, router.IsEmpty(), "all removed") + + // Reverse order: fallback first, then route. + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb2"}) + assert.False(t, router.IsEmpty()) + + router.AddRoute("b.com", Route{Type: RouteTCP, ServiceID: "svc-b"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb2") + assert.False(t, router.IsEmpty(), "route still present") + + router.RemoveRoute("b.com", "svc-b") + assert.True(t, router.IsEmpty(), "fully empty again") +} + +// startEchoTLSMulti starts a TLS echo server that accepts multiple connections. +func startEchoTLSMulti(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} + +// startEchoPlain starts a plain TCP echo server that reads until newline +// or connection close, then echoes the received data. +func startEchoPlain(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + // Set a read deadline so we don't block forever waiting for more data. + _ = c.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} + +// fakeAddr implements net.Addr with a custom string representation. +type fakeAddr string + +func (f fakeAddr) Network() string { return "tcp" } +func (f fakeAddr) String() string { return string(f) } + +// fakeConn is a minimal net.Conn with a controllable RemoteAddr. +type fakeConn struct { + net.Conn + remote net.Addr +} + +func (f *fakeConn) RemoteAddr() net.Addr { return f.remote } + +func TestCheckRestrictions_UnparseableAddress(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + route := Route{Filter: filter} + + conn := &fakeConn{remote: fakeAddr("not-an-ip")} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied") +} + +func TestCheckRestrictions_NilRemoteAddr(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + route := Route{Filter: filter} + + conn := &fakeConn{remote: nil} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied") +} + +func TestCheckRestrictions_AllowedAndDenied(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + route := Route{Filter: filter} + + allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}} + assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist") + + denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist") +} + +func TestCheckRestrictions_NilFilter(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + route := Route{Filter: nil} + + conn := &fakeConn{remote: fakeAddr("not-an-ip")} + assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything") +} + +func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}) + route := Route{Filter: filter} + + // net.IPv4() returns a 16-byte v4-in-v6 representation internally. + // The restriction check must Unmap it to match the v4 CIDR. + conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}} + assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR") + + // Explicitly v4-mapped-v6 address string. + conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")} + assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR") + + connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR") +} diff --git a/proxy/internal/tcp/snipeek.go b/proxy/internal/tcp/snipeek.go new file mode 100644 index 000000000..25ab8e5ef --- /dev/null +++ b/proxy/internal/tcp/snipeek.go @@ -0,0 +1,191 @@ +package tcp + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" +) + +const ( + // TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2). + tlsRecordHeaderLen = 5 + // TLS handshake type for ClientHello. + handshakeTypeClientHello = 1 + // TLS ContentType for handshake messages. + contentTypeHandshake = 22 + // SNI extension type (RFC 6066). + extensionServerName = 0 + // SNI host name type. + sniHostNameType = 0 + // maxClientHelloLen caps the ClientHello size we're willing to buffer. + maxClientHelloLen = 16384 + // maxSNILen is the maximum valid DNS hostname length per RFC 1035. + maxSNILen = 253 +) + +// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI +// server name, and returns a wrapped connection that replays the peeked +// bytes transparently. If the data is not a valid TLS ClientHello or +// contains no SNI extension, sni is empty and err is nil. +// +// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the +// real server name is encrypted inside the encrypted_client_hello +// extension. This parser only reads the cleartext server_name extension +// (type 0x0000), so ECH connections return sni="" and are routed through +// the fallback path (or HTTP channel), which is the correct behavior +// for a transparent proxy that does not terminate TLS. +func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) { + // Read the 5-byte TLS record header into a small stack-friendly buffer. + var header [tlsRecordHeaderLen]byte + if _, err := io.ReadFull(conn, header[:]); err != nil { + return "", nil, fmt.Errorf("read TLS record header: %w", err) + } + + if header[0] != contentTypeHandshake { + return "", newPeekedConn(conn, header[:]), nil + } + + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + if recordLen == 0 || recordLen > maxClientHelloLen { + return "", newPeekedConn(conn, header[:]), nil + } + + // Single allocation for header + payload. The peekedConn takes + // ownership of this buffer, so no further copies are needed. + buf := make([]byte, tlsRecordHeaderLen+recordLen) + copy(buf, header[:]) + + n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:]) + if err != nil { + return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err) + } + + sni = extractSNI(buf[tlsRecordHeaderLen:]) + return sni, newPeekedConn(conn, buf), nil +} + +// extractSNI parses a TLS handshake payload to find the SNI extension. +// Returns empty string if the payload is not a ClientHello or has no SNI. +func extractSNI(payload []byte) string { + if len(payload) < 4 { + return "" + } + + if payload[0] != handshakeTypeClientHello { + return "" + } + + // Handshake length (3 bytes, big-endian). + handshakeLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3]) + if handshakeLen > len(payload)-4 { + return "" + } + + return parseSNIFromClientHello(payload[4 : 4+handshakeLen]) +} + +// parseSNIFromClientHello walks the ClientHello message fields to reach +// the extensions block and extract the server_name extension value. +func parseSNIFromClientHello(msg []byte) string { + // ClientHello layout: + // ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id + if len(msg) < 34 { + return "" + } + + pos := 34 + + // Session ID (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + sessionIDLen := int(msg[pos]) + pos++ + pos += sessionIDLen + + // Cipher suites (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + cipherSuitesLen + + // Compression methods (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + compMethodsLen := int(msg[pos]) + pos++ + pos += compMethodsLen + + // Extensions (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + + extensionsEnd := pos + extensionsLen + if extensionsEnd > len(msg) { + return "" + } + + return findSNIExtension(msg[pos:extensionsEnd]) +} + +// findSNIExtension iterates over TLS extensions and returns the host +// name from the server_name extension, if present. +func findSNIExtension(extensions []byte) string { + pos := 0 + for pos+4 <= len(extensions) { + extType := binary.BigEndian.Uint16(extensions[pos : pos+2]) + extLen := int(binary.BigEndian.Uint16(extensions[pos+2 : pos+4])) + pos += 4 + + if pos+extLen > len(extensions) { + return "" + } + + if extType == extensionServerName { + return parseSNIExtensionData(extensions[pos : pos+extLen]) + } + pos += extLen + } + return "" +} + +// parseSNIExtensionData parses the ServerNameList structure inside an +// SNI extension to extract the host name. +func parseSNIExtensionData(data []byte) string { + if len(data) < 2 { + return "" + } + listLen := int(binary.BigEndian.Uint16(data[0:2])) + if listLen > len(data)-2 { + return "" + } + + list := data[2 : 2+listLen] + pos := 0 + for pos+3 <= len(list) { + nameType := list[pos] + nameLen := int(binary.BigEndian.Uint16(list[pos+1 : pos+3])) + pos += 3 + + if pos+nameLen > len(list) { + return "" + } + + if nameType == sniHostNameType { + name := list[pos : pos+nameLen] + if nameLen > maxSNILen || bytes.ContainsRune(name, 0) { + return "" + } + return string(name) + } + pos += nameLen + } + return "" +} diff --git a/proxy/internal/tcp/snipeek_test.go b/proxy/internal/tcp/snipeek_test.go new file mode 100644 index 000000000..9afe6261d --- /dev/null +++ b/proxy/internal/tcp/snipeek_test.go @@ -0,0 +1,251 @@ +package tcp + +import ( + "crypto/tls" + "io" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPeekClientHello_ValidSNI(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + const expectedSNI = "example.com" + trailingData := []byte("trailing data after handshake") + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: expectedSNI, + InsecureSkipVerify: true, //nolint:gosec + }) + // The Handshake will send the ClientHello. It will fail because + // our server side isn't doing a real TLS handshake, but that's + // fine: we only need the ClientHello to be sent. + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello") + assert.NotNil(t, wrapped, "wrapped connection should not be nil") + + // Verify the wrapped connection replays the peeked bytes. + // Read the first 5 bytes (TLS record header) to confirm replay. + buf := make([]byte, 5) + n, err := wrapped.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type") + + // Write trailing data from the client side and verify it arrives + // through the wrapped connection after the peeked bytes. + go func() { + _, _ = clientConn.Write(trailingData) + }() + + // Drain the rest of the peeked ClientHello first. + peekedRest := make([]byte, 16384) + _, _ = wrapped.Read(peekedRest) + + got := make([]byte, len(trailingData)) + n, err = io.ReadFull(wrapped, got) + require.NoError(t, err) + assert.Equal(t, trailingData, got[:n]) +} + +func TestPeekClientHello_MultipleSNIs(t *testing.T) { + tests := []struct { + name string + serverName string + expectedSNI string + }{ + {"simple domain", "example.com", "example.com"}, + {"subdomain", "sub.example.com", "sub.example.com"}, + {"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: tt.serverName, + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, tt.expectedSNI, sni) + assert.NotNil(t, wrapped) + }) + } +} + +func TestPeekClientHello_NonTLSData(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // Send plain HTTP data (not TLS). + httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + go func() { + _, _ = clientConn.Write(httpData) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni, "should return empty SNI for non-TLS data") + assert.NotNil(t, wrapped) + + // Verify the wrapped connection still provides the original data. + buf := make([]byte, len(httpData)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data") +} + +func TestPeekClientHello_TruncatedHeader(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write only 3 bytes then close, fewer than the 5-byte TLS header. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01}) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated header") +} + +func TestPeekClientHello_TruncatedPayload(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write a valid TLS header claiming 100 bytes, but only send 10. + go func() { + header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed + _, _ = clientConn.Write(header) + _, _ = clientConn.Write(make([]byte, 10)) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated payload") +} + +func TestPeekClientHello_ZeroLengthRecord(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // TLS handshake header with zero-length payload. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00}) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni) + assert.NotNil(t, wrapped) +} + +func TestExtractSNI_InvalidPayload(t *testing.T) { + tests := []struct { + name string + payload []byte + }{ + {"nil", nil}, + {"empty", []byte{}}, + {"too short", []byte{0x01, 0x00}}, + {"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}}, + {"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Empty(t, extractSNI(tt.payload)) + }) + } +} + +func TestPeekedConn_CloseWrite(t *testing.T) { + t.Run("delegates to underlying TCPConn", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + accepted := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err == nil { + accepted <- c + } + }() + + client, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer client.Close() + + server := <-accepted + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + + // CloseWrite should succeed on a real TCP connection. + err = wrapped.CloseWrite() + assert.NoError(t, err) + + // The client should see EOF on reads after CloseWrite. + buf := make([]byte, 1) + _, err = client.Read(buf) + assert.Equal(t, io.EOF, err, "client should see EOF after half-close") + }) + + t.Run("no-op on non-halfcloser", func(t *testing.T) { + // net.Pipe does not implement CloseWrite. + _, server := net.Pipe() + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + err := wrapped.CloseWrite() + assert.NoError(t, err, "should be no-op on non-halfcloser") + }) +} + +func TestPeekedConn_ReplayAndPassthrough(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + peeked := []byte("peeked-data") + subsequent := []byte("subsequent-data") + + wrapped := newPeekedConn(serverConn, peeked) + + go func() { + _, _ = clientConn.Write(subsequent) + }() + + // Read should return peeked data first. + buf := make([]byte, len(peeked)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, peeked, buf[:n]) + + // Then subsequent data from the real connection. + buf = make([]byte, len(subsequent)) + n, err = io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, subsequent, buf[:n]) +} diff --git a/proxy/internal/types/types.go b/proxy/internal/types/types.go index 41acfef40..bf3731803 100644 --- a/proxy/internal/types/types.go +++ b/proxy/internal/types/types.go @@ -1,5 +1,56 @@ // Package types defines common types used across the proxy package. package types +import ( + "context" + "net" + "time" +) + // AccountID represents a unique identifier for a NetBird account. type AccountID string + +// ServiceID represents a unique identifier for a proxy service. +type ServiceID string + +// ServiceMode describes how a reverse proxy service is exposed. +type ServiceMode string + +const ( + ServiceModeHTTP ServiceMode = "http" + ServiceModeTCP ServiceMode = "tcp" + ServiceModeUDP ServiceMode = "udp" + ServiceModeTLS ServiceMode = "tls" +) + +// IsL4 returns true for TCP, UDP, and TLS modes. +func (m ServiceMode) IsL4() bool { + return m == ServiceModeTCP || m == ServiceModeUDP || m == ServiceModeTLS +} + +// RelayDirection indicates the direction of a relayed packet. +type RelayDirection string + +const ( + RelayDirectionClientToBackend RelayDirection = "client_to_backend" + RelayDirectionBackendToClient RelayDirection = "backend_to_client" +) + +// DialContextFunc dials a backend through the WireGuard tunnel. +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// dialTimeoutKey is the context key for a per-request dial timeout. +type dialTimeoutKey struct{} + +// WithDialTimeout returns a context carrying a dial timeout that +// DialContext wrappers can use to scope the timeout to just the +// connection establishment phase. +func WithDialTimeout(ctx context.Context, d time.Duration) context.Context { + return context.WithValue(ctx, dialTimeoutKey{}, d) +} + +// DialTimeoutFromContext returns the dial timeout from the context, if set. +func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) { + d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration) + return d, ok && d > 0 +} diff --git a/proxy/internal/types/types_test.go b/proxy/internal/types/types_test.go new file mode 100644 index 000000000..dd9738442 --- /dev/null +++ b/proxy/internal/types/types_test.go @@ -0,0 +1,54 @@ +package types + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestServiceMode_IsL4(t *testing.T) { + tests := []struct { + mode ServiceMode + want bool + }{ + {ServiceModeHTTP, false}, + {ServiceModeTCP, true}, + {ServiceModeUDP, true}, + {ServiceModeTLS, true}, + {ServiceMode("unknown"), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + assert.Equal(t, tt.want, tt.mode.IsL4()) + }) + } +} + +func TestDialTimeoutContext(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 5*time.Second) + d, ok := DialTimeoutFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, 5*time.Second, d) + }) + + t.Run("missing", func(t *testing.T) { + _, ok := DialTimeoutFromContext(context.Background()) + assert.False(t, ok) + }) + + t.Run("zero returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 0) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "zero duration should return ok=false") + }) + + t.Run("negative returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), -1*time.Second) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "negative duration should return ok=false") + }) +} diff --git a/proxy/internal/udp/relay.go b/proxy/internal/udp/relay.go new file mode 100644 index 000000000..8293bfe81 --- /dev/null +++ b/proxy/internal/udp/relay.go @@ -0,0 +1,573 @@ +package udp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/proxy/internal/restrict" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +const ( + // DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup. + DefaultSessionTTL = 30 * time.Second + // cleanupInterval is how often the cleaner goroutine runs. + cleanupInterval = time.Minute + // maxPacketSize is the maximum UDP packet size we'll handle. + maxPacketSize = 65535 + // DefaultMaxSessions is the default cap on concurrent UDP sessions per relay. + DefaultMaxSessions = 1024 + // sessionCreateRate limits new session creation per second. + sessionCreateRate = 50 + // sessionCreateBurst is the burst allowance for session creation. + sessionCreateBurst = 100 + // defaultDialTimeout is the fallback dial timeout for backend connections. + defaultDialTimeout = 30 * time.Second +) + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// SessionObserver receives callbacks for UDP session lifecycle events. +// All methods must be safe for concurrent use. +type SessionObserver interface { + UDPSessionStarted(accountID types.AccountID) + UDPSessionEnded(accountID types.AccountID) + UDPSessionDialError(accountID types.AccountID) + UDPSessionRejected(accountID types.AccountID) + UDPPacketRelayed(direction types.RelayDirection, bytes int) +} + +// clientAddr is a typed key for UDP session lookups. +type clientAddr string + +// Relay listens for incoming UDP packets on a dedicated port and +// maintains per-client sessions that relay packets to a backend +// through the WireGuard tunnel. +type Relay struct { + logger *log.Entry + listener net.PacketConn + target string + domain string + accountID types.AccountID + serviceID types.ServiceID + dialFunc types.DialContextFunc + dialTimeout time.Duration + sessionTTL time.Duration + maxSessions int + filter *restrict.Filter + geo restrict.GeoResolver + + mu sync.RWMutex + sessions map[clientAddr]*session + + bufPool sync.Pool + sessLimiter *rate.Limiter + sessWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + observer SessionObserver + accessLog l4Logger +} + +type session struct { + backend net.Conn + addr net.Addr + createdAt time.Time + // lastSeen stores the last activity timestamp as unix nanoseconds. + lastSeen atomic.Int64 + cancel context.CancelFunc + // bytesIn tracks total bytes received from the client. + bytesIn atomic.Int64 + // bytesOut tracks total bytes sent back to the client. + bytesOut atomic.Int64 +} + +func (s *session) updateLastSeen() { + s.lastSeen.Store(time.Now().UnixNano()) +} + +func (s *session) idleDuration() time.Duration { + return time.Since(time.Unix(0, s.lastSeen.Load())) +} + +// RelayConfig holds the configuration for a UDP relay. +type RelayConfig struct { + Logger *log.Entry + Listener net.PacketConn + Target string + Domain string + AccountID types.AccountID + ServiceID types.ServiceID + DialFunc types.DialContextFunc + DialTimeout time.Duration + SessionTTL time.Duration + MaxSessions int + AccessLog l4Logger + // Filter holds connection-level IP/geo restrictions. Nil means no restrictions. + Filter *restrict.Filter + // Geo is the geolocation lookup used for country-based restrictions. + Geo restrict.GeoResolver +} + +// New creates a UDP relay for the given listener and backend target. +// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions. +// DialTimeout controls how long to wait for backend connections; use 0 for default. +// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL. +func New(parentCtx context.Context, cfg RelayConfig) *Relay { + maxSessions := cfg.MaxSessions + dialTimeout := cfg.DialTimeout + sessionTTL := cfg.SessionTTL + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + if sessionTTL <= 0 { + sessionTTL = DefaultSessionTTL + } + ctx, cancel := context.WithCancel(parentCtx) + return &Relay{ + logger: cfg.Logger, + listener: cfg.Listener, + target: cfg.Target, + domain: cfg.Domain, + accountID: cfg.AccountID, + serviceID: cfg.ServiceID, + accessLog: cfg.AccessLog, + dialFunc: cfg.DialFunc, + dialTimeout: dialTimeout, + sessionTTL: sessionTTL, + maxSessions: maxSessions, + filter: cfg.Filter, + geo: cfg.Geo, + sessions: make(map[clientAddr]*session), + bufPool: sync.Pool{ + New: func() any { + buf := make([]byte, maxPacketSize) + return &buf + }, + }, + sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst), + ctx: ctx, + cancel: cancel, + } +} + +// ServiceID returns the service ID associated with this relay. +func (r *Relay) ServiceID() types.ServiceID { + return r.serviceID +} + +// SetObserver sets the session lifecycle observer. Must be called before Serve. +func (r *Relay) SetObserver(obs SessionObserver) { + r.mu.Lock() + defer r.mu.Unlock() + r.observer = obs +} + +// getObserver returns the current session lifecycle observer. +func (r *Relay) getObserver() SessionObserver { + r.mu.RLock() + defer r.mu.RUnlock() + return r.observer +} + +// Serve starts the relay loop. It blocks until the context is canceled +// or the listener is closed. +func (r *Relay) Serve() { + go r.cleanupLoop() + + for { + bufp := r.bufPool.Get().(*[]byte) + buf := *bufp + + n, addr, err := r.listener.ReadFrom(buf) + if err != nil { + r.bufPool.Put(bufp) + if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + return + } + r.logger.Debugf("UDP read: %v", err) + continue + } + + data := buf[:n] + sess, err := r.getOrCreateSession(addr) + if err != nil { + r.bufPool.Put(bufp) + r.logger.Debugf("create UDP session for %s: %v", addr, err) + continue + } + + sess.updateLastSeen() + + nw, err := sess.backend.Write(data) + if err != nil { + r.bufPool.Put(bufp) + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to backend for %s: %v", addr, err) + } + r.removeSession(sess) + continue + } + sess.bytesIn.Add(int64(nw)) + + if obs := r.getObserver(); obs != nil { + obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) + } + r.bufPool.Put(bufp) + } +} + +// getOrCreateSession returns an existing session or creates a new one. +func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { + key := clientAddr(addr.String()) + + r.mu.RLock() + sess, ok := r.sessions[key] + r.mu.RUnlock() + if ok && sess != nil { + return sess, nil + } + + // Check before taking the write lock: if the relay is shutting down, + // don't create new sessions. This prevents orphaned goroutines when + // Serve() processes a packet that was already read before Close(). + if r.ctx.Err() != nil { + return nil, r.ctx.Err() + } + + if err := r.checkAccessRestrictions(addr); err != nil { + return nil, err + } + + r.mu.Lock() + + if sess, ok = r.sessions[key]; ok && sess != nil { + r.mu.Unlock() + return sess, nil + } + if ok { + // Another goroutine is dialing for this key, skip. + r.mu.Unlock() + return nil, fmt.Errorf("session dial in progress for %s", key) + } + + if len(r.sessions) >= r.maxSessions { + r.mu.Unlock() + if obs := r.getObserver(); obs != nil { + obs.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions) + } + + if !r.sessLimiter.Allow() { + r.mu.Unlock() + if obs := r.getObserver(); obs != nil { + obs.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session creation rate limited") + } + + // Reserve the slot with a nil session so concurrent callers for the same + // key see it exists and wait. Release the lock before dialing. + r.sessions[key] = nil + r.mu.Unlock() + + dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout) + backend, err := r.dialFunc(dialCtx, "udp", r.target) + dialCancel() + if err != nil { + r.mu.Lock() + delete(r.sessions, key) + r.mu.Unlock() + if obs := r.getObserver(); obs != nil { + obs.UDPSessionDialError(r.accountID) + } + return nil, fmt.Errorf("dial backend %s: %w", r.target, err) + } + + sessCtx, sessCancel := context.WithCancel(r.ctx) + sess = &session{ + backend: backend, + addr: addr, + createdAt: time.Now(), + cancel: sessCancel, + } + sess.updateLastSeen() + + r.mu.Lock() + r.sessions[key] = sess + r.mu.Unlock() + + if obs := r.getObserver(); obs != nil { + obs.UDPSessionStarted(r.accountID) + } + + r.sessWg.Go(func() { + r.relayBackendToClient(sessCtx, sess) + }) + + r.logger.Debugf("UDP session created for %s", addr) + return sess, nil +} + +func (r *Relay) checkAccessRestrictions(addr net.Addr) error { + if r.filter == nil { + return nil + } + clientIP, err := addrFromUDPAddr(addr) + if err != nil { + return fmt.Errorf("parse client address %s for restriction check: %w", addr, err) + } + if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow { + if r.filter.IsObserveOnly(v) { + r.logger.Debugf("CrowdSec observe: would block %s (%s)", clientIP, v) + r.logDeny(clientIP, v, true) + } else { + r.logDeny(clientIP, v, false) + return fmt.Errorf("access restricted for %s", addr) + } + } + return nil +} + +// relayBackendToClient reads packets from the backend and writes them +// back to the client through the public-facing listener. +func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { + bufp := r.bufPool.Get().(*[]byte) + defer r.bufPool.Put(bufp) + defer r.removeSession(sess) + + for ctx.Err() == nil { + data, ok := r.readBackendPacket(sess, *bufp) + if !ok { + return + } + if data == nil { + continue + } + + sess.updateLastSeen() + + nw, err := r.listener.WriteTo(data, sess.addr) + if err != nil { + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to client %s: %v", sess.addr, err) + } + return + } + sess.bytesOut.Add(int64(nw)) + + if obs := r.getObserver(); obs != nil { + obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) + } + } +} + +// readBackendPacket reads one packet from the backend with an idle deadline. +// Returns (data, true) on success, (nil, true) on idle timeout that should +// retry, or (nil, false) when the session should be torn down. +func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) { + if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil { + r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err) + return nil, false + } + + n, err := sess.backend.Read(buf) + if err != nil { + if netutil.IsTimeout(err) { + if sess.idleDuration() > r.sessionTTL { + return nil, false + } + return nil, true + } + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err) + } + return nil, false + } + + return buf[:n], true +} + +// cleanupLoop periodically removes idle sessions. +func (r *Relay) cleanupLoop() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.cleanupIdleSessions() + } + } +} + +// cleanupIdleSessions closes sessions that have been idle for too long. +func (r *Relay) cleanupIdleSessions() { + var expired []*session + + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + continue + } + idle := sess.idleDuration() + if idle > r.sessionTTL { + r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load()) + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close idle session %s backend: %v", sess.addr, err) + } + expired = append(expired, sess) + } + } + r.mu.Unlock() + + obs := r.getObserver() + for _, sess := range expired { + if obs != nil { + obs.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// removeSession removes a session from the map if it still matches the +// given pointer. This is safe to call concurrently with cleanupIdleSessions +// because the identity check prevents double-close when both paths race. +func (r *Relay) removeSession(sess *session) { + r.mu.Lock() + key := clientAddr(sess.addr.String()) + removed := r.sessions[key] == sess + if removed { + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + } + r.mu.Unlock() + + if removed { + r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + if obs := r.getObserver(); obs != nil { + obs.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// logSessionEnd sends an access log entry for a completed UDP session. +func (r *Relay) logSessionEnd(sess *session) { + if r.accessLog == nil { + return + } + + var sourceIP netip.Addr + if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + + r.accessLog.LogL4(accesslog.L4Entry{ + AccountID: r.accountID, + ServiceID: r.serviceID, + Protocol: accesslog.ProtocolUDP, + Host: r.domain, + SourceIP: sourceIP, + DurationMs: time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(), + BytesUpload: sess.bytesIn.Load(), + BytesDownload: sess.bytesOut.Load(), + }) +} + +// logDeny sends an access log entry for a denied UDP packet. +func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict, observeOnly bool) { + if r.accessLog == nil { + return + } + + entry := accesslog.L4Entry{ + AccountID: r.accountID, + ServiceID: r.serviceID, + Protocol: accesslog.ProtocolUDP, + Host: r.domain, + SourceIP: clientIP, + DenyReason: verdict.String(), + } + if verdict.IsCrowdSec() { + entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()} + if observeOnly { + entry.Metadata["crowdsec_mode"] = "observe" + entry.DenyReason = "" + } + } + r.accessLog.LogL4(entry) +} + +// Close stops the relay, waits for all session goroutines to exit, +// and cleans up remaining sessions. +func (r *Relay) Close() { + r.cancel() + if err := r.listener.Close(); err != nil { + r.logger.Debugf("close UDP listener: %v", err) + } + + var closedSessions []*session + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + delete(r.sessions, key) + continue + } + r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + delete(r.sessions, key) + closedSessions = append(closedSessions, sess) + } + r.mu.Unlock() + + obs := r.getObserver() + for _, sess := range closedSessions { + if obs != nil { + obs.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } + + r.sessWg.Wait() +} + +// addrFromUDPAddr extracts a netip.Addr from a net.Addr. +func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) { + ap, err := netip.ParseAddrPort(addr.String()) + if err != nil { + return netip.Addr{}, err + } + return ap.Addr().Unmap(), nil +} diff --git a/proxy/internal/udp/relay_test.go b/proxy/internal/udp/relay_test.go new file mode 100644 index 000000000..a1e91b290 --- /dev/null +++ b/proxy/internal/udp/relay_test.go @@ -0,0 +1,493 @@ +package udp + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRelay_BasicPacketExchange(t *testing.T) { + // Set up a UDP backend that echoes packets. + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + // Set up the relay's public-facing listener. + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + backendAddr := backend.LocalAddr().String() + + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Create a client and send a packet to the relay. + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + defer client.Close() + + testData := []byte("hello UDP relay") + _, err = client.Write(testData) + require.NoError(t, err) + + // Read the echoed response. + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := client.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed packet") +} + +func TestRelay_MultipleClients(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Two clients, each should get their own session. + for i, msg := range []string{"client-1", "client-2"} { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + defer client.Close() + + _, err = client.Write([]byte(msg)) + require.NoError(t, err) + + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := client.Read(buf) + require.NoError(t, err, "client %d read", i) + assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i) + } + + // Verify two sessions were created. + relay.mu.RLock() + sessionCount := len(relay.sessions) + relay.mu.RUnlock() + assert.Equal(t, 2, sessionCount, "should have two sessions") +} + +func TestRelay_Close(t *testing.T) { + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dialFunc}) + + done := make(chan struct{}) + go func() { + relay.Serve() + close(done) + }() + + relay.Close() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Serve did not return after Close") + } +} + +func TestRelay_SessionCleanup(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay.Serve() + defer relay.Close() + + // Create a session. + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err) + client.Close() + + // Verify session exists. + relay.mu.RLock() + assert.Equal(t, 1, len(relay.sessions)) + relay.mu.RUnlock() + + // Make session appear idle by setting lastSeen to the past. + relay.mu.Lock() + for _, sess := range relay.sessions { + sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano()) + } + relay.mu.Unlock() + + // Trigger cleanup manually. + relay.cleanupIdleSessions() + + relay.mu.RLock() + assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up") + relay.mu.RUnlock() +} + +// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new +// one on the same port works cleanly (simulates port mapping modify cycle). +func TestRelay_CloseAndRecreate(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + // First relay. + ln1, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay1.Serve() + + client1, err := net.Dial("udp", ln1.LocalAddr().String()) + require.NoError(t, err) + _, err = client1.Write([]byte("relay1")) + require.NoError(t, err) + require.NoError(t, client1.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + n, err := client1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "relay1", string(buf[:n])) + client1.Close() + + // Close first relay. + relay1.Close() + + // Second relay on same port. + port := ln1.LocalAddr().(*net.UDPAddr).Port + ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + + relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: backend.LocalAddr().String(), DialFunc: dialFunc}) + go relay2.Serve() + defer relay2.Close() + + client2, err := net.Dial("udp", ln2.LocalAddr().String()) + require.NoError(t, err) + defer client2.Close() + _, err = client2.Write([]byte("relay2")) + require.NoError(t, err) + require.NoError(t, client2.SetReadDeadline(time.Now().Add(2*time.Second))) + n, err = client2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "relay2", string(buf[:n]), "second relay should work on same port") +} + +func TestRelay_SessionLimit(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + // Create a relay with a max of 2 sessions. + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2}) + go relay.Serve() + defer relay.Close() + + // Create 2 clients to fill up the session limit. + for i := range 2 { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + defer client.Close() + + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err, "client %d should get response", i) + } + + relay.mu.RLock() + assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions") + relay.mu.RUnlock() + + // Third client should get its packet dropped (session creation fails). + client3, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err) + defer client3.Close() + + _, err = client3.Write([]byte("should be dropped")) + require.NoError(t, err) + + require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond))) + buf := make([]byte, 1024) + _, err = client3.Read(buf) + assert.Error(t, err, "third client should time out because session was rejected") + + relay.mu.RLock() + assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit") + relay.mu.RUnlock() +} + +// testObserver records UDP session lifecycle events for test assertions. +type testObserver struct { + mu sync.Mutex + started int + ended int + rejected int + dialErr int + packets int + bytes int +} + +func (o *testObserver) UDPSessionStarted(types.AccountID) { o.mu.Lock(); o.started++; o.mu.Unlock() } +func (o *testObserver) UDPSessionEnded(types.AccountID) { o.mu.Lock(); o.ended++; o.mu.Unlock() } +func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() } +func (o *testObserver) UDPSessionRejected(types.AccountID) { o.mu.Lock(); o.rejected++; o.mu.Unlock() } +func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) { + o.mu.Lock() + o.packets++ + o.bytes += b + o.mu.Unlock() +} + +func TestRelay_CloseFiresObserverEnded(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + obs := &testObserver{} + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc}) + relay.SetObserver(obs) + go relay.Serve() + + // Create two sessions. + for i := range 2 { + client, err := net.Dial("udp", listener.LocalAddr().String()) + require.NoError(t, err, "client %d", i) + + _, err = client.Write([]byte("hello")) + require.NoError(t, err) + + require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second))) + buf := make([]byte, 1024) + _, err = client.Read(buf) + require.NoError(t, err) + client.Close() + } + + obs.mu.Lock() + assert.Equal(t, 2, obs.started, "should have 2 started events") + obs.mu.Unlock() + + // Close should fire UDPSessionEnded for all remaining sessions. + relay.Close() + + obs.mu.Lock() + assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session") + obs.mu.Unlock() +} + +func TestRelay_SessionRateLimit(t *testing.T) { + backend, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer backend.Close() + + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := backend.ReadFrom(buf) + if err != nil { + return + } + _, _ = backend.WriteTo(buf[:n], addr) + } + }() + + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewEntry(log.StandardLogger()) + dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + } + + obs := &testObserver{} + // High max sessions (1000) but the relay uses a rate limiter internally + // (default: 50/s burst 100). We exhaust the burst by creating sessions + // rapidly, then verify that subsequent creates are rejected. + relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc, MaxSessions: 1000}) + relay.SetObserver(obs) + go relay.Serve() + defer relay.Close() + + // Exhaust the burst by calling getOrCreateSession directly with + // synthetic addresses. This is faster than real UDP round-trips. + for i := range sessionCreateBurst + 20 { + addr := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i} + _, _ = relay.getOrCreateSession(addr) + } + + obs.mu.Lock() + rejected := obs.rejected + obs.mu.Unlock() + + assert.Greater(t, rejected, 0, "some sessions should be rate-limited") +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 1163c50f4..4b1ecf922 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -18,9 +18,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -37,7 +39,7 @@ type integrationTestSetup struct { grpcServer *grpc.Server grpcAddr string cleanup func() - services []*reverseproxy.Service + services []*service.Service } func setupIntegrationTest(t *testing.T) *integrationTestSetup { @@ -66,13 +68,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { privKey := base64.StdEncoding.EncodeToString(priv) // Create test services in the store - services := []*reverseproxy.Service{ + services := []*service.Service{ { ID: "rp-1", AccountID: "test-account-1", Name: "Test App 1", Domain: "app1.test.proxy.io", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, @@ -91,7 +93,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { AccountID: "test-account-1", Name: "Test App 2", Domain: "app2.test.proxy.io", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, @@ -112,7 +114,11 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { } // Create real token store - tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) // Create real users manager usersManager := users.NewManager(testStore) @@ -124,17 +130,24 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { HMACKey: []byte("test-hmac-key"), } + proxyManager := &testProxyManager{} + proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, + pkceStore, oidcConfig, nil, usersManager, + proxyManager, ) // Use store-backed service manager svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} - proxyService.SetProxyManager(svcMgr) + proxyService.SetServiceManager(svcMgr) + + proxyController := &testProxyController{} + proxyService.SetProxyController(proxyController) // Start real gRPC server lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -185,25 +198,91 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, return nil, 0, nil } +// testProxyManager is a mock implementation of proxy.Manager for testing. +type testProxyManager struct{} + +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { + return nil +} + +func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { + return nil +} + +func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) { + return nil, nil +} + +func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { + return nil, nil +} + +func (m *testProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { + return nil +} + +// testProxyController is a mock implementation of rpservice.ProxyController for testing. +type testProxyController struct{} + +func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) { + // noop +} + +func (c *testProxyController) GetOIDCValidationConfig() nbproxy.OIDCValidationConfig { + return nbproxy.OIDCValidationConfig{} +} + +func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error { + return nil +} + +func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error { + return nil +} + +func (c *testProxyController) GetProxiesForCluster(_ string) []string { + return nil +} + // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store tokenStore *nbgrpc.OneTimeTokenStore } -func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + +func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } -func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) } -func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, errors.New("not implemented") } -func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, errors.New("not implemented") } @@ -215,7 +294,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, return nil } -func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { +func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error { return nil } @@ -227,15 +306,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID return nil } -func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1") } -func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) } -func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } @@ -243,6 +322,24 @@ func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, return "", nil } +func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + return &service.ExposeServiceResponse{}, nil +} + +func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} + +func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { + return nil, nil +} + func strPtr(s string) *string { return &s } @@ -410,7 +507,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T logger := log.New() logger.SetLevel(log.WarnLevel) - authMw := auth.NewMiddleware(logger, nil) + authMw := auth.NewMiddleware(logger, nil, nil) proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger) clusterAddress := "test.proxy.io" @@ -429,15 +526,16 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T nil, "", 0, - mapping.GetAccountId(), - mapping.GetId(), + proxytypes.AccountID(mapping.GetAccountId()), + proxytypes.ServiceID(mapping.GetId()), + nil, ) require.NoError(t, err) // Apply to real proxy (idempotent) proxyHandler.AddMapping(proxy.Mapping{ Host: mapping.GetDomain(), - ID: mapping.GetId(), + ID: proxytypes.ServiceID(mapping.GetId()), AccountID: proxytypes.AccountID(mapping.GetAccountId()), }) } diff --git a/proxy/server.go b/proxy/server.go index 60811e53b..fbd0d058e 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -19,14 +19,18 @@ import ( "net/netip" "net/url" "path/filepath" + "reflect" "sync" "time" "github.com/cenkalti/backoff/v4" "github.com/pires/go-proxyproto" - "github.com/prometheus/client_golang/prometheus" + prometheus2 "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/sdk/metric" + "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -38,20 +42,33 @@ import ( "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/proxy/internal/conntrack" + "github.com/netbirdio/netbird/proxy/internal/crowdsec" "github.com/netbirdio/netbird/proxy/internal/debug" + "github.com/netbirdio/netbird/proxy/internal/geolocation" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" - "github.com/netbirdio/netbird/proxy/internal/metrics" + proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" "github.com/netbirdio/netbird/proxy/internal/types" + udprelay "github.com/netbirdio/netbird/proxy/internal/udp" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/embeddedroots" ) +// portRouter bundles a per-port Router with its listener and cancel func. +type portRouter struct { + router *nbtcp.Router + listener net.Listener + cancel context.CancelFunc +} + type Server struct { mgmtClient proto.ProxyServiceClient proxy *proxy.ReverseProxy @@ -63,12 +80,37 @@ type Server struct { debug *http.Server healthServer *health.Server healthChecker *health.Checker - meter *metrics.Metrics + meter *proxymetrics.Metrics + accessLog *accesslog.Logger + mainRouter *nbtcp.Router + mainPort uint16 + udpMu sync.Mutex + udpRelays map[types.ServiceID]*udprelay.Relay + udpRelayWg sync.WaitGroup + portMu sync.RWMutex + portRouters map[uint16]*portRouter + svcPorts map[types.ServiceID][]uint16 + lastMappings map[types.ServiceID]*proto.ProxyMapping + portRouterWg sync.WaitGroup // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // so they can be closed during graceful shutdown, since http.Server.Shutdown // does not handle them. hijackTracker conntrack.HijackTracker + // geo resolves IP addresses to country/city for access restrictions and access logs. + geo restrict.GeoResolver + geoRaw *geolocation.Lookup + + // crowdsecRegistry manages the shared CrowdSec bouncer lifecycle. + crowdsecRegistry *crowdsec.Registry + // crowdsecServices tracks which services have CrowdSec enabled for + // proper acquire/release lifecycle management. + crowdsecMu sync.Mutex + crowdsecServices map[types.ServiceID]bool + + // routerReady is closed once mainRouter is fully initialized. + // The mapping worker waits on this before processing updates. + routerReady chan struct{} // Mostly used for debugging on management. startTime time.Time @@ -84,12 +126,21 @@ type Server struct { GenerateACMECertificates bool ACMEChallengeAddress string ACMEDirectory string + // ACMEEABKID is the External Account Binding Key ID for CAs that require EAB (e.g., ZeroSSL). + ACMEEABKID string + // ACMEEABHMACKey is the External Account Binding HMAC key (base64 URL-encoded) for CAs that require EAB. + ACMEEABHMACKey string // ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01". // Defaults to "tls-alpn-01" if not specified. ACMEChallengeType string // CertLockMethod controls how ACME certificate locks are coordinated // across replicas. Default: CertLockAuto (detect environment). CertLockMethod acme.CertLockMethod + // WildcardCertDir is an optional directory containing wildcard certificate + // pairs (.crt / .key). Wildcard patterns are extracted from + // the certificates' SAN lists. Matching domains use these static certs + // instead of ACME. + WildcardCertDir string // DebugEndpointEnabled enables the debug HTTP endpoint. DebugEndpointEnabled bool @@ -106,26 +157,72 @@ type Server struct { // When set, forwarding headers from these sources are preserved and // appended to instead of being stripped. TrustedProxies []netip.Prefix - // WireguardPort is the port for the WireGuard interface. Use 0 for a - // random OS-assigned port. A fixed port only works with single-account - // deployments; multiple accounts will fail to bind the same port. - WireguardPort int + // WireguardPort is the port for the NetBird tunnel interface. Use 0 + // for a random OS-assigned port. A fixed port only works with + // single-account deployments; multiple accounts will fail to bind + // the same port. + WireguardPort uint16 // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. ProxyProtocol bool + // PreSharedKey used for tunnel between proxy and peers (set globally not per account) + PreSharedKey string + // SupportsCustomPorts indicates whether the proxy can bind arbitrary + // ports for TCP/UDP/TLS services. + SupportsCustomPorts bool + // RequireSubdomain indicates whether a subdomain label is required + // in front of this proxy's cluster domain. When true, accounts cannot + // create services on the bare cluster domain. + RequireSubdomain bool + // MaxDialTimeout caps the per-service backend dial timeout. + // When the API sends a timeout, it is clamped to this value. + // When the API sends no timeout, this value is used as the default. + // Zero means no cap (the proxy honors whatever management sends). + MaxDialTimeout time.Duration + // GeoDataDir is the directory containing GeoLite2 MMDB files for + // country-based access restrictions. Empty disables geo lookups. + GeoDataDir string + // CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec. + CrowdSecAPIURL string + // CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables CrowdSec. + CrowdSecAPIKey string + // MaxSessionIdleTimeout caps the per-service session idle timeout. + // Zero means no cap (the proxy honors whatever management sends). + // Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments. + MaxSessionIdleTimeout time.Duration } -// NotifyStatus sends a status update to management about tunnel connectivity -func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error { +// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured. +func (s *Server) clampIdleTimeout(d time.Duration) time.Duration { + if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout { + return s.MaxSessionIdleTimeout + } + return d +} + +// clampDialTimeout returns d capped to MaxDialTimeout when configured. +// If d is zero, MaxDialTimeout is used as the default. +func (s *Server) clampDialTimeout(d time.Duration) time.Duration { + if s.MaxDialTimeout <= 0 { + return d + } + if d <= 0 || d > s.MaxDialTimeout { + return s.MaxDialTimeout + } + return d +} + +// NotifyStatus sends a status update to management about tunnel connectivity. +func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED if connected { status = proto.ProxyStatus_PROXY_STATUS_ACTIVE } _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: status, CertificateIssued: false, }) @@ -133,10 +230,10 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain } // NotifyCertificateIssued sends a notification to management that a certificate was issued -func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error { +func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error { _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE, CertificateIssued: true, }) @@ -145,9 +242,25 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.initDefaults() + s.routerReady = make(chan struct{}) + s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) + s.portRouters = make(map[uint16]*portRouter) + s.svcPorts = make(map[types.ServiceID][]uint16) + s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) - reg := prometheus.NewRegistry() - s.meter = metrics.New(reg) + exporter, err := prometheus.New() + if err != nil { + return fmt.Errorf("create prometheus exporter: %w", err) + } + + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + pkg := reflect.TypeOf(Server{}).PkgPath() + meter := provider.Meter(pkg) + + s.meter, err = proxymetrics.New(ctx, meter) + if err != nil { + return fmt.Errorf("create metrics: %w", err) + } mgmtConn, err := s.dialManagement() if err != nil { @@ -159,11 +272,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { } }() s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) - go s.newManagementMappingWorker(ctx, s.mgmtClient) + runCtx, runCancel := context.WithCancel(ctx) + defer runCancel() // Initialize the netbird client, this is required to build peer connections // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) + s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{ + MgmtAddr: s.ManagementAddress, + WGPort: s.WireguardPort, + PreSharedKey: s.PreSharedKey, + }, s.Logger, s, s.mgmtClient) + + // Create health checker before the mapping worker so it can track + // management connectivity from the first stream connection. + s.healthChecker = health.NewChecker(s.Logger, s.netbird) + + s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger)) + s.crowdsecServices = make(map[types.ServiceID]bool) + + go s.newManagementMappingWorker(runCtx, s.mgmtClient) tlsConfig, err := s.configureTLS(ctx) if err != nil { @@ -173,17 +300,36 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) + geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir) + if err != nil { + return fmt.Errorf("initialize geolocation: %w", err) + } + s.geoRaw = geoLookup + if geoLookup != nil { + s.geo = geoLookup + } + + var startupOK bool + defer func() { + if startupOK { + return + } + if s.geoRaw != nil { + if err := s.geoRaw.Close(); err != nil { + s.Logger.Debugf("close geolocation on startup failure: %v", err) + } + } + }() + // Configure the authentication middleware with session validator for OIDC group checks. - s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) + s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo) // Configure Access logs to management server. - accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) - - s.healthChecker = health.NewChecker(s.Logger, s.netbird) + s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.startDebugEndpoint() - if err := s.startHealthServer(reg); err != nil { + if err := s.startHealthServer(); err != nil { return err } @@ -191,18 +337,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { handler := http.Handler(s.proxy) handler = s.auth.Protect(handler) handler = web.AssetHandler(handler) - handler = accessLog.Middleware(handler) + handler = s.accessLog.Middleware(handler) handler = s.meter.Middleware(handler) handler = s.hijackTracker.Middleware(handler) - // Start the reverse proxy HTTPS server. - s.https = &http.Server{ - Addr: addr, - Handler: handler, - TLSConfig: tlsConfig, - ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), - } - + // Start a raw TCP listener; the SNI router peeks at ClientHello + // and routes to either the HTTP handler or a TCP relay. lc := net.ListenConfig{} ln, err := lc.Listen(ctx, "tcp", addr) if err != nil { @@ -211,11 +351,36 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { if s.ProxyProtocol { ln = s.wrapProxyProtocol(ln) } + s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid + + // Set up the SNI router for TCP/HTTP multiplexing on the main port. + s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr()) + s.mainRouter.SetObserver(s.meter) + s.mainRouter.SetAccessLogger(s.accessLog) + close(s.routerReady) + + // The HTTP server uses the chanListener fed by the SNI router. + s.https = &http.Server{ + Addr: addr, + Handler: handler, + TLSConfig: tlsConfig, + ReadHeaderTimeout: httpReadHeaderTimeout, + IdleTimeout: httpIdleTimeout, + ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), + } + + startupOK = true httpsErr := make(chan error, 1) go func() { - s.Logger.Debugf("starting reverse proxy server on %s", addr) - httpsErr <- s.https.ServeTLS(ln, "", "") + s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") + httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "") + }() + + routerErr := make(chan error, 1) + go func() { + s.Logger.Debugf("starting SNI router on %s", addr) + routerErr <- s.mainRouter.Serve(runCtx, ln) }() select { @@ -225,6 +390,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return fmt.Errorf("https server: %w", err) } return nil + case err := <-routerErr: + s.shutdownServices() + if err != nil { + return fmt.Errorf("SNI router: %w", err) + } + return nil case <-ctx.Done(): s.gracefulShutdown() return nil @@ -274,12 +445,12 @@ func (s *Server) startDebugEndpoint() { } // startHealthServer launches the health probe and metrics server. -func (s *Server) startHealthServer(reg *prometheus.Registry) error { +func (s *Server) startHealthServer() error { healthAddr := s.HealthAddress if healthAddr == "" { healthAddr = defaultHealthAddr } - s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) + s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true})) healthListener, err := net.Listen("tcp", healthAddr) if err != nil { return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err) @@ -352,6 +523,13 @@ const ( // shutdownServiceTimeout is the maximum time to wait for auxiliary // services (health probe, debug endpoint, ACME) to shut down. shutdownServiceTimeout = 5 * time.Second + + // httpReadHeaderTimeout limits how long the server waits to read + // request headers after accepting a connection. Prevents slowloris. + httpReadHeaderTimeout = 10 * time.Second + // httpIdleTimeout limits how long an idle keep-alive connection + // stays open before the server closes it. + httpIdleTimeout = 120 * time.Second ) func (s *Server) dialManagement() (*grpc.ClientConn, error) { @@ -413,7 +591,20 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { "acme_server": s.ACMEDirectory, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificates enabled, configuring certificate manager") - s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod) + var err error + s.acme, err = acme.NewManager(acme.ManagerConfig{ + CertDir: s.CertificateDirectory, + ACMEURL: s.ACMEDirectory, + EABKID: s.ACMEEABKID, + EABHMACKey: s.ACMEEABHMACKey, + LockMethod: s.CertLockMethod, + WildcardDir: s.WildcardCertDir, + }, s, s.Logger, s.meter) + if err != nil { + return nil, fmt.Errorf("create ACME manager: %w", err) + } + + go s.acme.WatchWildcards(ctx) if s.ACMEChallengeType == "http-01" { s.http = &http.Server{ @@ -429,6 +620,10 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { } tlsConfig = s.acme.TLSConfig() + // autocert.Manager.TLSConfig() wires its own GetCertificate, which + // bypasses our override that checks wildcards first. + tlsConfig.GetCertificate = s.acme.GetCertificate + // ServerName needs to be set to allow for ACME to work correctly // when using CNAME URLs to access the proxy. tlsConfig.ServerName = s.ProxyURL @@ -472,6 +667,9 @@ func (s *Server) gracefulShutdown() { s.Logger.Infof("closed %d hijacked connection(s)", n) } + // Drain all router relay connections (main + per-port) in parallel. + s.drainAllRouters(shutdownDrainTimeout) + // Step 5: Stop all remaining background services. s.shutdownServices() s.Logger.Info("graceful shutdown complete") @@ -479,6 +677,34 @@ func (s *Server) gracefulShutdown() { // shutdownServices stops all background services concurrently and waits for // them to finish. +// drainAllRouters drains active relay connections on the main router and +// all per-port routers in parallel, up to the given timeout. +func (s *Server) drainAllRouters(timeout time.Duration) { + var wg sync.WaitGroup + + drain := func(name string, router *nbtcp.Router) { + wg.Add(1) + go func() { + defer wg.Done() + if ok := router.Drain(timeout); !ok { + s.Logger.Warnf("timed out draining %s relay connections", name) + } + }() + } + + if s.mainRouter != nil { + drain("main router", s.mainRouter) + } + + s.portMu.RLock() + for port, pr := range s.portRouters { + drain(fmt.Sprintf("port %d", port), pr.router) + } + s.portMu.RUnlock() + + wg.Wait() +} + func (s *Server) shutdownServices() { var wg sync.WaitGroup @@ -516,7 +742,189 @@ func (s *Server) shutdownServices() { }() } + // Close all UDP relays and wait for their goroutines to exit. + s.udpMu.Lock() + for id, relay := range s.udpRelays { + relay.Close() + delete(s.udpRelays, id) + } + s.udpMu.Unlock() + s.udpRelayWg.Wait() + + // Close all per-port routers. + s.portMu.Lock() + for port, pr := range s.portRouters { + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + } + maps.Clear(s.svcPorts) + maps.Clear(s.lastMappings) + s.portMu.Unlock() + + // Wait for per-port router serve goroutines to exit. + s.portRouterWg.Wait() + wg.Wait() + + if s.accessLog != nil { + s.accessLog.Close() + } + + if s.geoRaw != nil { + if err := s.geoRaw.Close(); err != nil { + s.Logger.Debugf("close geolocation: %v", err) + } + } + + s.shutdownCrowdSec() +} + +func (s *Server) shutdownCrowdSec() { + if s.crowdsecRegistry == nil { + return + } + s.crowdsecMu.Lock() + services := maps.Clone(s.crowdsecServices) + maps.Clear(s.crowdsecServices) + s.crowdsecMu.Unlock() + + for svcID := range services { + s.crowdsecRegistry.Release(svcID) + } +} + +// resolveDialFunc returns a DialContextFunc that dials through the +// NetBird tunnel for the given account. +func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFunc, error) { + client, ok := s.netbird.GetClient(accountID) + if !ok { + return nil, fmt.Errorf("no client for account %s", accountID) + } + return client.DialContext, nil +} + +// notifyError reports a resource error back to management so it can be +// surfaced to the user (e.g. port bind failure, dialer resolution error). +func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) { + s.sendStatusUpdate(ctx, types.AccountID(mapping.GetAccountId()), types.ServiceID(mapping.GetId()), proto.ProxyStatus_PROXY_STATUS_ERROR, err) +} + +// sendStatusUpdate sends a status update for a service to management. +func (s *Server) sendStatusUpdate(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, st proto.ProxyStatus, err error) { + req := &proto.SendStatusUpdateRequest{ + ServiceId: string(serviceID), + AccountId: string(accountID), + Status: st, + } + if err != nil { + msg := err.Error() + req.ErrorMessage = &msg + } + if _, sendErr := s.mgmtClient.SendStatusUpdate(ctx, req); sendErr != nil { + s.Logger.Debugf("failed to send status update for %s: %v", serviceID, sendErr) + } +} + +// routerForPort returns the router that handles the given listen port. If port +// is 0 or matches the main listener port, the main router is returned. +// Otherwise a new per-port router is created and started. +func (s *Server) routerForPort(ctx context.Context, port uint16) (*nbtcp.Router, error) { + if port == 0 || port == s.mainPort { + return s.mainRouter, nil + } + return s.getOrCreatePortRouter(ctx, port) +} + +// routerForPortExisting returns the router for the given port without creating +// one. Returns the main router for port 0 / mainPort, or nil if no per-port +// router exists. +func (s *Server) routerForPortExisting(port uint16) *nbtcp.Router { + if port == 0 || port == s.mainPort { + return s.mainRouter + } + s.portMu.RLock() + pr := s.portRouters[port] + s.portMu.RUnlock() + if pr != nil { + return pr.router + } + return nil +} + +// getOrCreatePortRouter returns an existing per-port router or creates one +// with a new TCP listener and starts serving. +func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp.Router, error) { + s.portMu.Lock() + defer s.portMu.Unlock() + + if pr, ok := s.portRouters[port]; ok { + return pr.router, nil + } + + listenAddr := fmt.Sprintf(":%d", port) + ln, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, fmt.Errorf("listen TCP on %s: %w", listenAddr, err) + } + if s.ProxyProtocol { + ln = s.wrapProxyProtocol(ln) + } + + router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc) + router.SetObserver(s.meter) + router.SetAccessLogger(s.accessLog) + portCtx, cancel := context.WithCancel(ctx) + + s.portRouters[port] = &portRouter{ + router: router, + listener: ln, + cancel: cancel, + } + + s.portRouterWg.Add(1) + go func() { + defer s.portRouterWg.Done() + if err := router.Serve(portCtx, ln); err != nil { + s.Logger.Debugf("port %d router stopped: %v", port, err) + } + }() + + s.Logger.Debugf("started per-port router on %s", listenAddr) + return router, nil +} + +// cleanupPortIfEmpty tears down a per-port router if it has no remaining +// routes or fallback. The main port is never cleaned up. Active relay +// connections are drained before the listener is closed. +func (s *Server) cleanupPortIfEmpty(port uint16) { + if port == 0 || port == s.mainPort { + return + } + + s.portMu.Lock() + pr, ok := s.portRouters[port] + if !ok || !pr.router.IsEmpty() { + s.portMu.Unlock() + return + } + + // Cancel and close the listener while holding the lock so that + // getOrCreatePortRouter sees the entry is gone before we drain. + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + s.portMu.Unlock() + + // Drain active relay connections outside the lock. + if ok := pr.router.Drain(nbtcp.DefaultDrainTimeout); !ok { + s.Logger.Warnf("timed out draining relay connections on port %d", port) + } + s.Logger.Debugf("cleaned up empty per-port router on port %d", port) } func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { @@ -539,11 +947,17 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr s.healthChecker.SetManagementConnected(false) } + supportsCrowdSec := s.crowdsecRegistry.Available() mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ ProxyId: s.ID, Version: s.Version, StartedAt: timestamppb.New(s.startTime), Address: s.ProxyURL, + Capabilities: &proto.ProxyCapabilities{ + SupportsCustomPorts: &s.SupportsCustomPorts, + RequireSubdomain: &s.RequireSubdomain, + SupportsCrowdsec: &supportsCrowdSec, + }, }) if err != nil { return fmt.Errorf("create mapping stream: %w", err) @@ -580,6 +994,12 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + for { // Check for context completion to gracefully shutdown. select { @@ -616,25 +1036,28 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap s.Logger.WithFields(log.Fields{ "type": mapping.GetType(), "domain": mapping.GetDomain(), - "path": mapping.GetPath(), + "mode": mapping.GetMode(), + "port": mapping.GetListenPort(), "id": mapping.GetId(), }).Debug("Processing mapping update") switch mapping.GetType() { case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: if err := s.addMapping(ctx, mapping); err != nil { - // TODO: Retry this? Or maybe notify the management server that this mapping has failed? s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), "error": err, }).Error("Error adding new mapping, ignoring this mapping and continuing processing") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: - if err := s.updateMapping(ctx, mapping); err != nil { + if err := s.modifyMapping(ctx, mapping); err != nil { s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), - }).Errorf("failed to update mapping: %v", err) + "error": err, + }).Error("failed to modify mapping") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: s.removeMapping(ctx, mapping) @@ -642,26 +1065,412 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap } } +// addMapping registers a service mapping and starts the appropriate relay or routes. func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - serviceID := mapping.GetId() + svcID := types.ServiceID(mapping.GetId()) authToken := mapping.GetAuthToken() - if err := s.netbird.AddPeer(ctx, accountID, d, authToken, serviceID); err != nil { - return fmt.Errorf("create peer for domain %q: %w", d, err) - } - if s.acme != nil { - s.acme.AddDomain(d, string(accountID), serviceID) + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.AddPeer(ctx, accountID, svcKey, authToken, svcID); err != nil { + return fmt.Errorf("create peer for service %s: %w", svcID, err) } - // Pass the mapping through to the update function to avoid duplicating the - // setup, currently update is simply a subset of this function, so this - // separation makes sense...to me at least. + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + if peerErr := s.netbird.RemovePeer(ctx, accountID, svcKey); peerErr != nil { + s.Logger.WithError(peerErr).WithField("service_id", svcID).Warn("failed to remove peer after setup failure") + } + return err + } + s.storeMapping(mapping) + return nil +} + +// modifyMapping updates a service mapping in place without tearing down the +// NetBird peer. It cleans up old routes using the previously stored mapping +// state and re-applies them from the new mapping. +func (s *Server) modifyMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + if old := s.loadMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) + } + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + return err + } + s.storeMapping(mapping) + return nil +} + +// setupMappingRoutes configures the appropriate routes or relays for the given +// service mapping based on its mode. The NetBird peer must already exist. +func (s *Server) setupMappingRoutes(ctx context.Context, mapping *proto.ProxyMapping) error { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP: + return s.setupTCPMapping(ctx, mapping) + case types.ServiceModeUDP: + return s.setupUDPMapping(ctx, mapping) + case types.ServiceModeTLS: + return s.setupTLSMapping(ctx, mapping) + default: + return s.setupHTTPMapping(ctx, mapping) + } +} + +// setupHTTPMapping configures HTTP reverse proxy, auth, and ACME routes. +func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + d := domain.Domain(mapping.GetDomain()) + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + + if len(mapping.GetPath()) == 0 { + return nil + } + + var wildcardHit bool + if s.acme != nil { + wildcardHit = s.acme.AddDomain(d, accountID, svcID) + } + s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteHTTP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + }) if err := s.updateMapping(ctx, mapping); err != nil { - s.removeMapping(ctx, mapping) return fmt.Errorf("update mapping for domain %q: %w", d, err) } + + if wildcardHit { + if err := s.NotifyCertificateIssued(ctx, accountID, svcID, string(d)); err != nil { + s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err) + } + } + + return nil +} + +// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port. +func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TCP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TCP service %s", svcID) + } + + if s.WireguardPort != 0 && port == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", port) + } + + router, err := s.routerForPort(ctx, port) + if err != nil { + return fmt.Errorf("router for TCP port %d: %w", port, err) + } + + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + + router.SetGeo(s.geo) + router.SetFallback(nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTCP, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), + Filter: s.parseRestrictions(mapping), + }) + + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{port} + s.portMu.Unlock() + + s.meter.L4ServiceAdded(types.ServiceModeTCP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupUDPMapping starts a UDP relay on the listen port. +func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("UDP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for UDP service %s", svcID) + } + + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + + if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { + return fmt.Errorf("UDP relay for service %s: %w", svcID, err) + } + + s.meter.L4ServiceAdded(types.ServiceModeUDP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupTLSMapping configures a TLS SNI-routed passthrough on the listen port. +func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + tlsPort, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TLS service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TLS service %s", svcID) + } + + if s.WireguardPort != 0 && tlsPort == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", tlsPort) + } + + router, err := s.routerForPort(ctx, tlsPort) + if err != nil { + return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) + } + + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + + router.SetGeo(s.geo) + router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTLS, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), + Filter: s.parseRestrictions(mapping), + }) + + if tlsPort != s.mainPort { + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{tlsPort} + s.portMu.Unlock() + } + + s.Logger.WithFields(log.Fields{ + "domain": mapping.GetDomain(), + "target": targetAddr, + "port": tlsPort, + "service": svcID, + }).Info("TLS passthrough mapping added") + + s.meter.L4ServiceAdded(types.ServiceModeTLS) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// serviceKeyForMapping returns the appropriate ServiceKey for a mapping. +// TCP/UDP use an ID-based key; HTTP/TLS use a domain-based key. +func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.ServiceKey { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP, types.ServiceModeUDP: + return roundtrip.L4ServiceKey(types.ServiceID(mapping.GetId())) + default: + return roundtrip.DomainServiceKey(mapping.GetDomain()) + } +} + +// parseRestrictions converts a proto mapping's access restrictions into +// a restrict.Filter. Returns nil if the mapping has no restrictions. +func (s *Server) parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter { + r := mapping.GetAccessRestrictions() + if r == nil { + return nil + } + + svcID := types.ServiceID(mapping.GetId()) + csMode := restrict.CrowdSecMode(r.GetCrowdsecMode()) + + var checker restrict.CrowdSecChecker + if csMode == restrict.CrowdSecEnforce || csMode == restrict.CrowdSecObserve { + if b := s.crowdsecRegistry.Acquire(svcID); b != nil { + checker = b + s.crowdsecMu.Lock() + s.crowdsecServices[svcID] = true + s.crowdsecMu.Unlock() + } else { + s.Logger.Warnf("service %s requests CrowdSec mode %q but proxy has no CrowdSec configured", svcID, csMode) + // Keep the mode: restrict.Filter will fail-closed for enforce (DenyCrowdSecUnavailable) + // and allow for observe. + } + } + + return restrict.ParseFilter(restrict.FilterConfig{ + AllowedCIDRs: r.GetAllowedCidrs(), + BlockedCIDRs: r.GetBlockedCidrs(), + AllowedCountries: r.GetAllowedCountries(), + BlockedCountries: r.GetBlockedCountries(), + CrowdSec: checker, + CrowdSecMode: csMode, + Logger: log.NewEntry(s.Logger), + }) +} + +// releaseCrowdSec releases the CrowdSec bouncer reference for the given +// service if it had one. +func (s *Server) releaseCrowdSec(svcID types.ServiceID) { + s.crowdsecMu.Lock() + had := s.crowdsecServices[svcID] + delete(s.crowdsecServices, svcID) + s.crowdsecMu.Unlock() + + if had { + s.crowdsecRegistry.Release(svcID) + } +} + +// warnIfGeoUnavailable logs a warning if the mapping has country restrictions +// but the proxy has no geolocation database loaded. All requests to this +// service will be denied at runtime (fail-close). +func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) { + if r == nil { + return + } + if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 { + return + } + if s.geo != nil && s.geo.Available() { + return + } + s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain) +} + +// l4TargetAddress extracts and validates the target address from a mapping's +// first path entry. Returns empty string if no paths exist or the address is +// not a valid host:port. +func (s *Server) l4TargetAddress(mapping *proto.ProxyMapping) string { + paths := mapping.GetPath() + if len(paths) == 0 { + return "" + } + target := paths[0].GetTarget() + if _, _, err := net.SplitHostPort(target); err != nil { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "target": target, + }).Warnf("invalid L4 target address: %v", err) + return "" + } + return target +} + +// l4ProxyProtocol returns whether the first target has PROXY protocol enabled. +func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool { + paths := mapping.GetPath() + if len(paths) == 0 { + return false + } + return paths[0].GetOptions().GetProxyProtocol() +} + +// l4DialTimeout returns the dial timeout from the first target's options, +// clamped to MaxDialTimeout. +func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { + return s.clampDialTimeout(d.AsDuration()) + } + } + return s.clampDialTimeout(0) +} + +// l4SessionIdleTimeout returns the configured session idle timeout from the +// mapping options, or 0 to use the relay's default. +func l4SessionIdleTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetSessionIdleTimeout(); d != nil { + return d.AsDuration() + } + } + return 0 +} + +// addUDPRelay starts a UDP relay on the specified listen port. +func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, targetAddress string, listenPort uint16) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + if s.WireguardPort != 0 && listenPort == s.WireguardPort { + return fmt.Errorf("UDP port %d conflicts with tunnel port", listenPort) + } + + // Close existing relay if present (idempotent re-add). + s.removeUDPRelay(svcID) + + listenAddr := fmt.Sprintf(":%d", listenPort) + + listener, err := net.ListenPacket("udp", listenAddr) + if err != nil { + return fmt.Errorf("listen UDP on %s: %w", listenAddr, err) + } + + dialFn, err := s.resolveDialFunc(accountID) + if err != nil { + if err := listener.Close(); err != nil { + s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err) + } + return fmt.Errorf("resolve dialer for UDP: %w", err) + } + + entry := s.Logger.WithFields(log.Fields{ + "target": targetAddress, + "listen_port": listenPort, + "service_id": svcID, + }) + + relay := udprelay.New(ctx, udprelay.RelayConfig{ + Logger: entry, + Listener: listener, + Target: targetAddress, + Domain: mapping.GetDomain(), + AccountID: accountID, + ServiceID: svcID, + DialFunc: dialFn, + DialTimeout: s.l4DialTimeout(mapping), + SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), + AccessLog: s.accessLog, + Filter: s.parseRestrictions(mapping), + Geo: s.geo, + }) + relay.SetObserver(s.meter) + + s.udpMu.Lock() + s.udpRelays[svcID] = relay + s.udpMu.Unlock() + + s.udpRelayWg.Go(relay.Serve) + entry.Info("UDP relay added") return nil } @@ -671,50 +1480,151 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) // the auth and proxy mappings. // Note: this does require the management server to always send a // full mapping rather than deltas during a modification. + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + var schemes []auth.Scheme if mapping.GetAuth().GetPassword() { - schemes = append(schemes, auth.NewPassword(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPassword(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetPin() { - schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPin(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetOidc() { - schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), s.ForwardedProto)) + schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) + } + for _, ha := range mapping.GetAuth().GetHeaderAuths() { + schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader())) } + ipRestrictions := s.parseRestrictions(mapping) + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, mapping.GetAccountId(), mapping.GetId()); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } - s.proxy.AddMapping(s.protoToMapping(mapping)) - s.meter.AddMapping(s.protoToMapping(mapping)) + m := s.protoToMapping(ctx, mapping) + s.proxy.AddMapping(m) + s.meter.AddMapping(m) return nil } +// removeMapping tears down routes/relays and the NetBird peer for a service. +// Uses the stored mapping state when available to ensure all previously +// configured routes are cleaned up. func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil { + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil { s.Logger.WithFields(log.Fields{ "account_id": accountID, - "domain": d, + "service_id": mapping.GetId(), "error": err, - }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") + }).Error("failed to remove NetBird peer, continuing cleanup") } - if s.acme != nil { - s.acme.RemoveDomain(d) + + if old := s.deleteMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) } - s.auth.RemoveDomain(mapping.GetDomain()) - s.proxy.RemoveMapping(s.protoToMapping(mapping)) - s.meter.RemoveMapping(s.protoToMapping(mapping)) } -func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { - paths := make(map[string]*url.URL) +// cleanupMappingRoutes removes HTTP/TLS/L4 routes and custom port state for a +// service without touching the NetBird peer. This is used for both full +// removal and in-place modification of mappings. +func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { + svcID := types.ServiceID(mapping.GetId()) + host := mapping.GetDomain() + + // HTTP/TLS cleanup (only relevant when a domain is set). + if host != "" { + d := domain.Domain(host) + if s.acme != nil { + s.acme.RemoveDomain(d) + } + s.auth.RemoveDomain(host) + if s.proxy.RemoveMapping(proxy.Mapping{Host: host}) { + s.meter.RemoveMapping(proxy.Mapping{Host: host}) + } + // Close hijacked connections (WebSocket) for this domain. + if n := s.hijackTracker.CloseByHost(host); n > 0 { + s.Logger.Debugf("closed %d hijacked connection(s) for %s", n, host) + } + // Remove SNI route from the main router (covers both HTTP and main-port TLS). + s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID) + } + + // Extract and delete tracked custom-port entries atomically. + s.portMu.Lock() + entries := s.svcPorts[svcID] + delete(s.svcPorts, svcID) + s.portMu.Unlock() + + for _, entry := range entries { + if router := s.routerForPortExisting(entry); router != nil { + if host != "" { + router.RemoveRoute(nbtcp.SNIHost(host), svcID) + } else { + router.RemoveFallback(svcID) + } + } + s.cleanupPortIfEmpty(entry) + } + + // UDP relay cleanup (idempotent). + s.removeUDPRelay(svcID) + + // Release CrowdSec after all routes are removed so the shared bouncer + // isn't stopped while stale filters can still be reached by in-flight requests. + s.releaseCrowdSec(svcID) +} + +// removeUDPRelay stops and removes a UDP relay by service ID. +func (s *Server) removeUDPRelay(svcID types.ServiceID) { + s.udpMu.Lock() + relay, ok := s.udpRelays[svcID] + if ok { + delete(s.udpRelays, svcID) + } + s.udpMu.Unlock() + + if ok { + relay.Close() + s.Logger.WithField("service_id", svcID).Info("UDP relay removed") + } +} + +func (s *Server) storeMapping(mapping *proto.ProxyMapping) { + s.portMu.Lock() + s.lastMappings[types.ServiceID(mapping.GetId())] = mapping + s.portMu.Unlock() +} + +func (s *Server) loadMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.RLock() + m := s.lastMappings[svcID] + s.portMu.RUnlock() + return m +} + +func (s *Server) deleteMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.Lock() + m := s.lastMappings[svcID] + delete(s.lastMappings, svcID) + s.portMu.Unlock() + return m +} + +func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping) proxy.Mapping { + paths := make(map[string]*proxy.PathTarget) for _, pathMapping := range mapping.GetPath() { targetURL, err := url.Parse(pathMapping.GetTarget()) if err != nil { - // TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure? s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "account_id": mapping.GetAccountId(), @@ -722,18 +1632,43 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { "path": pathMapping.GetPath(), "target": pathMapping.GetTarget(), }).WithError(err).Error("failed to parse target URL for path, skipping") + s.notifyError(ctx, mapping, fmt.Errorf("invalid target URL %q for path %q: %w", pathMapping.GetTarget(), pathMapping.GetPath(), err)) continue } - paths[pathMapping.GetPath()] = targetURL + + pt := &proxy.PathTarget{URL: targetURL} + if opts := pathMapping.GetOptions(); opts != nil { + pt.SkipTLSVerify = opts.GetSkipTlsVerify() + pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite()) + pt.CustomHeaders = opts.GetCustomHeaders() + if d := opts.GetRequestTimeout(); d != nil { + pt.RequestTimeout = d.AsDuration() + } + } + pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout) + paths[pathMapping.GetPath()] = pt } - return proxy.Mapping{ - ID: mapping.GetId(), + m := proxy.Mapping{ + ID: types.ServiceID(mapping.GetId()), AccountID: types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), Paths: paths, PassHostHeader: mapping.GetPassHostHeader(), RewriteRedirects: mapping.GetRewriteRedirects(), } + for _, ha := range mapping.GetAuth().GetHeaderAuths() { + m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader()) + } + return m +} + +func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode { + switch mode { + case proto.PathRewriteMode_PATH_REWRITE_PRESERVE: + return proxy.PathRewritePreserve + default: + return proxy.PathRewriteDefault + } } // debugEndpointAddr returns the address for the debug endpoint. diff --git a/proxy/web/package-lock.json b/proxy/web/package-lock.json index d16196d77..1611323a7 100644 --- a/proxy/web/package-lock.json +++ b/proxy/web/package-lock.json @@ -15,7 +15,7 @@ "tailwind-merge": "^2.6.0" }, "devDependencies": { - "@eslint/js": "^9.39.1", + "@eslint/js": "9.39.2", "@tailwindcss/vite": "^4.1.18", "@types/node": "^24.10.1", "@types/react": "^19.2.5", @@ -29,7 +29,7 @@ "tsx": "^4.21.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "7.3.2" } }, "node_modules/@babel/code-frame": { @@ -1024,9 +1024,9 @@ "license": "MIT" }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.57.1.tgz", - "integrity": "sha512-A6ehUVSiSaaliTxai040ZpZ2zTevHYbvu/lDoeAteHI8QnaosIzm4qwtezfRg1jOYaUmnzLX1AOD6Z+UJjtifg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.0.tgz", + "integrity": "sha512-WOhNW9K8bR3kf4zLxbfg6Pxu2ybOUbB2AjMDHSQx86LIF4rH4Ft7vmMwNt0loO0eonglSNy4cpD3MKXXKQu0/A==", "cpu": [ "arm" ], @@ -1038,9 +1038,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.57.1.tgz", - "integrity": "sha512-dQaAddCY9YgkFHZcFNS/606Exo8vcLHwArFZ7vxXq4rigo2bb494/xKMMwRRQW6ug7Js6yXmBZhSBRuBvCCQ3w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.0.tgz", + "integrity": "sha512-u6JHLll5QKRvjciE78bQXDmqRqNs5M/3GVqZeMwvmjaNODJih/WIrJlFVEihvV0MiYFmd+ZyPr9wxOVbPAG2Iw==", "cpu": [ "arm64" ], @@ -1052,9 +1052,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.57.1.tgz", - "integrity": "sha512-crNPrwJOrRxagUYeMn/DZwqN88SDmwaJ8Cvi/TN1HnWBU7GwknckyosC2gd0IqYRsHDEnXf328o9/HC6OkPgOg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.0.tgz", + "integrity": "sha512-qEF7CsKKzSRc20Ciu2Zw1wRrBz4g56F7r/vRwY430UPp/nt1x21Q/fpJ9N5l47WWvJlkNCPJz3QRVw008fi7yA==", "cpu": [ "arm64" ], @@ -1066,9 +1066,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.57.1.tgz", - "integrity": "sha512-Ji8g8ChVbKrhFtig5QBV7iMaJrGtpHelkB3lsaKzadFBe58gmjfGXAOfI5FV0lYMH8wiqsxKQ1C9B0YTRXVy4w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.0.tgz", + "integrity": "sha512-WADYozJ4QCnXCH4wPB+3FuGmDPoFseVCUrANmA5LWwGmC6FL14BWC7pcq+FstOZv3baGX65tZ378uT6WG8ynTw==", "cpu": [ "x64" ], @@ -1080,9 +1080,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.57.1.tgz", - "integrity": "sha512-R+/WwhsjmwodAcz65guCGFRkMb4gKWTcIeLy60JJQbXrJ97BOXHxnkPFrP+YwFlaS0m+uWJTstrUA9o+UchFug==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.0.tgz", + "integrity": "sha512-6b8wGHJlDrGeSE3aH5mGNHBjA0TTkxdoNHik5EkvPHCt351XnigA4pS7Wsj/Eo9Y8RBU6f35cjN9SYmCFBtzxw==", "cpu": [ "arm64" ], @@ -1094,9 +1094,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.57.1.tgz", - "integrity": "sha512-IEQTCHeiTOnAUC3IDQdzRAGj3jOAYNr9kBguI7MQAAZK3caezRrg0GxAb6Hchg4lxdZEI5Oq3iov/w/hnFWY9Q==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.0.tgz", + "integrity": "sha512-h25Ga0t4jaylMB8M/JKAyrvvfxGRjnPQIR8lnCayyzEjEOx2EJIlIiMbhpWxDRKGKF8jbNH01NnN663dH638mA==", "cpu": [ "x64" ], @@ -1108,9 +1108,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.57.1.tgz", - "integrity": "sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.0.tgz", + "integrity": "sha512-RzeBwv0B3qtVBWtcuABtSuCzToo2IEAIQrcyB/b2zMvBWVbjo8bZDjACUpnaafaxhTw2W+imQbP2BD1usasK4g==", "cpu": [ "arm" ], @@ -1122,9 +1122,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.57.1.tgz", - "integrity": "sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.0.tgz", + "integrity": "sha512-Sf7zusNI2CIU1HLzuu9Tc5YGAHEZs5Lu7N1ssJG4Tkw6e0MEsN7NdjUDDfGNHy2IU+ENyWT+L2obgWiguWibWQ==", "cpu": [ "arm" ], @@ -1136,9 +1136,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.57.1.tgz", - "integrity": "sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.0.tgz", + "integrity": "sha512-DX2x7CMcrJzsE91q7/O02IJQ5/aLkVtYFryqCjduJhUfGKG6yJV8hxaw8pZa93lLEpPTP/ohdN4wFz7yp/ry9A==", "cpu": [ "arm64" ], @@ -1150,9 +1150,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.57.1.tgz", - "integrity": "sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.0.tgz", + "integrity": "sha512-09EL+yFVbJZlhcQfShpswwRZ0Rg+z/CsSELFCnPt3iK+iqwGsI4zht3secj5vLEs957QvFFXnzAT0FFPIxSrkQ==", "cpu": [ "arm64" ], @@ -1164,9 +1164,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.57.1.tgz", - "integrity": "sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.0.tgz", + "integrity": "sha512-i9IcCMPr3EXm8EQg5jnja0Zyc1iFxJjZWlb4wr7U2Wx/GrddOuEafxRdMPRYVaXjgbhvqalp6np07hN1w9kAKw==", "cpu": [ "loong64" ], @@ -1178,9 +1178,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.57.1.tgz", - "integrity": "sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.0.tgz", + "integrity": "sha512-DGzdJK9kyJ+B78MCkWeGnpXJ91tK/iKA6HwHxF4TAlPIY7GXEvMe8hBFRgdrR9Ly4qebR/7gfUs9y2IoaVEyog==", "cpu": [ "loong64" ], @@ -1192,9 +1192,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.57.1.tgz", - "integrity": "sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.0.tgz", + "integrity": "sha512-RwpnLsqC8qbS8z1H1AxBA1H6qknR4YpPR9w2XX0vo2Sz10miu57PkNcnHVaZkbqyw/kUWfKMI73jhmfi9BRMUQ==", "cpu": [ "ppc64" ], @@ -1206,9 +1206,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.57.1.tgz", - "integrity": "sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.0.tgz", + "integrity": "sha512-Z8pPf54Ly3aqtdWC3G4rFigZgNvd+qJlOE52fmko3KST9SoGfAdSRCwyoyG05q1HrrAblLbk1/PSIV+80/pxLg==", "cpu": [ "ppc64" ], @@ -1220,9 +1220,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.57.1.tgz", - "integrity": "sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.0.tgz", + "integrity": "sha512-3a3qQustp3COCGvnP4SvrMHnPQ9d1vzCakQVRTliaz8cIp/wULGjiGpbcqrkv0WrHTEp8bQD/B3HBjzujVWLOA==", "cpu": [ "riscv64" ], @@ -1234,9 +1234,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.57.1.tgz", - "integrity": "sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.0.tgz", + "integrity": "sha512-pjZDsVH/1VsghMJ2/kAaxt6dL0psT6ZexQVrijczOf+PeP2BUqTHYejk3l6TlPRydggINOeNRhvpLa0AYpCWSQ==", "cpu": [ "riscv64" ], @@ -1248,9 +1248,9 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.57.1.tgz", - "integrity": "sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.0.tgz", + "integrity": "sha512-3ObQs0BhvPgiUVZrN7gqCSvmFuMWvWvsjG5ayJ3Lraqv+2KhOsp+pUbigqbeWqueGIsnn+09HBw27rJ+gYK4VQ==", "cpu": [ "s390x" ], @@ -1262,9 +1262,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.57.1.tgz", - "integrity": "sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.0.tgz", + "integrity": "sha512-EtylprDtQPdS5rXvAayrNDYoJhIz1/vzN2fEubo3yLE7tfAw+948dO0g4M0vkTVFhKojnF+n6C8bDNe+gDRdTg==", "cpu": [ "x64" ], @@ -1276,9 +1276,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.57.1.tgz", - "integrity": "sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.0.tgz", + "integrity": "sha512-k09oiRCi/bHU9UVFqD17r3eJR9bn03TyKraCrlz5ULFJGdJGi7VOmm9jl44vOJvRJ6P7WuBi/s2A97LxxHGIdw==", "cpu": [ "x64" ], @@ -1290,9 +1290,9 @@ ] }, "node_modules/@rollup/rollup-openbsd-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.57.1.tgz", - "integrity": "sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.0.tgz", + "integrity": "sha512-1o/0/pIhozoSaDJoDcec+IVLbnRtQmHwPV730+AOD29lHEEo4F5BEUB24H0OBdhbBBDwIOSuf7vgg0Ywxdfiiw==", "cpu": [ "x64" ], @@ -1304,9 +1304,9 @@ ] }, "node_modules/@rollup/rollup-openharmony-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.57.1.tgz", - "integrity": "sha512-4wYoDpNg6o/oPximyc/NG+mYUejZrCU2q+2w6YZqrAs2UcNUChIZXjtafAiiZSUc7On8v5NyNj34Kzj/Ltk6dQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.0.tgz", + "integrity": "sha512-pESDkos/PDzYwtyzB5p/UoNU/8fJo68vcXM9ZW2V0kjYayj1KaaUfi1NmTUTUpMn4UhU4gTuK8gIaFO4UGuMbA==", "cpu": [ "arm64" ], @@ -1318,9 +1318,9 @@ ] }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.57.1.tgz", - "integrity": "sha512-O54mtsV/6LW3P8qdTcamQmuC990HDfR71lo44oZMZlXU4tzLrbvTii87Ni9opq60ds0YzuAlEr/GNwuNluZyMQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.0.tgz", + "integrity": "sha512-hj1wFStD7B1YBeYmvY+lWXZ7ey73YGPcViMShYikqKT1GtstIKQAtfUI6yrzPjAy/O7pO0VLXGmUVWXQMaYgTQ==", "cpu": [ "arm64" ], @@ -1332,9 +1332,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.57.1.tgz", - "integrity": "sha512-P3dLS+IerxCT/7D2q2FYcRdWRl22dNbrbBEtxdWhXrfIMPP9lQhb5h4Du04mdl5Woq05jVCDPCMF7Ub0NAjIew==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.0.tgz", + "integrity": "sha512-SyaIPFoxmUPlNDq5EHkTbiKzmSEmq/gOYFI/3HHJ8iS/v1mbugVa7dXUzcJGQfoytp9DJFLhHH4U3/eTy2Bq4w==", "cpu": [ "ia32" ], @@ -1346,9 +1346,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.57.1.tgz", - "integrity": "sha512-VMBH2eOOaKGtIJYleXsi2B8CPVADrh+TyNxJ4mWPnKfLB/DBUmzW+5m1xUrcwWoMfSLagIRpjUFeW5CO5hyciQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.0.tgz", + "integrity": "sha512-RdcryEfzZr+lAr5kRm2ucN9aVlCCa2QNq4hXelZxb8GG0NJSazq44Z3PCCc8wISRuCVnGs0lQJVX5Vp6fKA+IA==", "cpu": [ "x64" ], @@ -1360,9 +1360,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.57.1.tgz", - "integrity": "sha512-mxRFDdHIWRxg3UfIIAwCm6NzvxG0jDX/wBN6KsQFTvKFqqg9vTrWUE68qEjHt19A5wwx5X5aUi2zuZT7YR0jrA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.0.tgz", + "integrity": "sha512-PrsWNQ8BuE00O3Xsx3ALh2Df8fAj9+cvvX9AIA6o4KpATR98c9mud4XtDWVvsEuyia5U4tVSTKygawyJkjm60w==", "cpu": [ "x64" ], @@ -1926,9 +1926,9 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", + "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", "dev": true, "license": "MIT", "dependencies": { @@ -1936,13 +1936,13 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "version": "9.0.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", + "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", "dev": true, "license": "ISC", "dependencies": { - "brace-expansion": "^2.0.1" + "brace-expansion": "^2.0.2" }, "engines": { "node": ">=16 || 14 >=14.17" @@ -2052,9 +2052,9 @@ } }, "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "6.14.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", + "integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==", "dev": true, "license": "MIT", "dependencies": { @@ -2109,9 +2109,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "dev": true, "license": "MIT", "dependencies": { @@ -2657,9 +2657,9 @@ } }, "node_modules/flatted": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", - "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", + "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", "dev": true, "license": "ISC" }, @@ -3243,9 +3243,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -3386,9 +3386,9 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", - "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", + "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "dev": true, "license": "MIT", "peer": true, @@ -3501,9 +3501,9 @@ } }, "node_modules/rollup": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.57.1.tgz", - "integrity": "sha512-oQL6lgK3e2QZeQ7gcgIkS2YZPg5slw37hYufJ3edKlfQSGGm8ICoxswK15ntSzF/a8+h7ekRy7k7oWc3BQ7y8A==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.0.tgz", + "integrity": "sha512-yqjxruMGBQJ2gG4HtjZtAfXArHomazDHoFwFFmZZl0r7Pdo7qCIXKqKHZc8yeoMgzJJ+pO6pEEHa+V7uzWlrAQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3517,31 +3517,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.57.1", - "@rollup/rollup-android-arm64": "4.57.1", - "@rollup/rollup-darwin-arm64": "4.57.1", - "@rollup/rollup-darwin-x64": "4.57.1", - "@rollup/rollup-freebsd-arm64": "4.57.1", - "@rollup/rollup-freebsd-x64": "4.57.1", - "@rollup/rollup-linux-arm-gnueabihf": "4.57.1", - "@rollup/rollup-linux-arm-musleabihf": "4.57.1", - "@rollup/rollup-linux-arm64-gnu": "4.57.1", - "@rollup/rollup-linux-arm64-musl": "4.57.1", - "@rollup/rollup-linux-loong64-gnu": "4.57.1", - "@rollup/rollup-linux-loong64-musl": "4.57.1", - "@rollup/rollup-linux-ppc64-gnu": "4.57.1", - "@rollup/rollup-linux-ppc64-musl": "4.57.1", - "@rollup/rollup-linux-riscv64-gnu": "4.57.1", - "@rollup/rollup-linux-riscv64-musl": "4.57.1", - "@rollup/rollup-linux-s390x-gnu": "4.57.1", - "@rollup/rollup-linux-x64-gnu": "4.57.1", - "@rollup/rollup-linux-x64-musl": "4.57.1", - "@rollup/rollup-openbsd-x64": "4.57.1", - "@rollup/rollup-openharmony-arm64": "4.57.1", - "@rollup/rollup-win32-arm64-msvc": "4.57.1", - "@rollup/rollup-win32-ia32-msvc": "4.57.1", - "@rollup/rollup-win32-x64-gnu": "4.57.1", - "@rollup/rollup-win32-x64-msvc": "4.57.1", + "@rollup/rollup-android-arm-eabi": "4.60.0", + "@rollup/rollup-android-arm64": "4.60.0", + "@rollup/rollup-darwin-arm64": "4.60.0", + "@rollup/rollup-darwin-x64": "4.60.0", + "@rollup/rollup-freebsd-arm64": "4.60.0", + "@rollup/rollup-freebsd-x64": "4.60.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.60.0", + "@rollup/rollup-linux-arm-musleabihf": "4.60.0", + "@rollup/rollup-linux-arm64-gnu": "4.60.0", + "@rollup/rollup-linux-arm64-musl": "4.60.0", + "@rollup/rollup-linux-loong64-gnu": "4.60.0", + "@rollup/rollup-linux-loong64-musl": "4.60.0", + "@rollup/rollup-linux-ppc64-gnu": "4.60.0", + "@rollup/rollup-linux-ppc64-musl": "4.60.0", + "@rollup/rollup-linux-riscv64-gnu": "4.60.0", + "@rollup/rollup-linux-riscv64-musl": "4.60.0", + "@rollup/rollup-linux-s390x-gnu": "4.60.0", + "@rollup/rollup-linux-x64-gnu": "4.60.0", + "@rollup/rollup-linux-x64-musl": "4.60.0", + "@rollup/rollup-openbsd-x64": "4.60.0", + "@rollup/rollup-openharmony-arm64": "4.60.0", + "@rollup/rollup-win32-arm64-msvc": "4.60.0", + "@rollup/rollup-win32-ia32-msvc": "4.60.0", + "@rollup/rollup-win32-x64-gnu": "4.60.0", + "@rollup/rollup-win32-x64-msvc": "4.60.0", "fsevents": "~2.3.2" } }, @@ -3803,9 +3803,9 @@ } }, "node_modules/vite": { - "version": "7.3.1", - "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz", - "integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==", + "version": "7.3.2", + "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz", + "integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==", "dev": true, "license": "MIT", "peer": true, diff --git a/proxy/web/package.json b/proxy/web/package.json index 97ec1ec0d..9a7c84ed4 100644 --- a/proxy/web/package.json +++ b/proxy/web/package.json @@ -17,7 +17,7 @@ "tailwind-merge": "^2.6.0" }, "devDependencies": { - "@eslint/js": "^9.39.1", + "@eslint/js": "9.39.2", "@tailwindcss/vite": "^4.1.18", "@types/node": "^24.10.1", "@types/react": "^19.2.5", @@ -31,6 +31,6 @@ "tsx": "^4.21.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "7.3.2" } } diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 8c3ee1899..067888406 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -1,11 +1,13 @@ package server import ( + "context" "fmt" - "net" + "time" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/shared/relay/messages" //nolint:staticcheck "github.com/netbirdio/netbird/shared/relay/messages/address" @@ -13,6 +15,12 @@ import ( authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth" ) +const ( + // handshakeTimeout bounds how long a connection may remain in the + // pre-authentication handshake phase before being closed. + handshakeTimeout = 10 * time.Second +) + type Validator interface { Validate(any) error // Deprecated: Use Validate instead. @@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { } type handshake struct { - conn net.Conn + conn listener.Conn validator Validator preparedMsg *preparedMsg @@ -66,9 +74,9 @@ type handshake struct { peerID *messages.PeerID } -func (h *handshake) handshakeReceive() (*messages.PeerID, error) { +func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) { buf := make([]byte, messages.MaxHandshakeSize) - n, err := h.conn.Read(buf) + n, err := h.conn.Read(ctx, buf) if err != nil { return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) } @@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) { return peerID, nil } -func (h *handshake) handshakeResponse() error { +func (h *handshake) handshakeResponse(ctx context.Context) error { var responseMsg []byte if h.handshakeMethodAuth { responseMsg = h.preparedMsg.responseAuthMsg @@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error { responseMsg = h.preparedMsg.responseHelloMsg } - if _, err := h.conn.Write(responseMsg); err != nil { + if _, err := h.conn.Write(ctx, responseMsg); err != nil { return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) } diff --git a/relay/server/listener/conn.go b/relay/server/listener/conn.go new file mode 100644 index 000000000..ef0869594 --- /dev/null +++ b/relay/server/listener/conn.go @@ -0,0 +1,14 @@ +package listener + +import ( + "context" + "net" +) + +// Conn is the relay connection contract implemented by WS and QUIC transports. +type Conn interface { + Read(ctx context.Context, b []byte) (n int, err error) + Write(ctx context.Context, b []byte) (n int, err error) + RemoteAddr() net.Addr + Close() error +} diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go deleted file mode 100644 index 0a79182f4..000000000 --- a/relay/server/listener/listener.go +++ /dev/null @@ -1,14 +0,0 @@ -package listener - -import ( - "context" - "net" - - "github.com/netbirdio/netbird/relay/protocol" -) - -type Listener interface { - Listen(func(conn net.Conn)) error - Shutdown(ctx context.Context) error - Protocol() protocol.Protocol -} diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go index 6e2201bf7..d8dafcd1f 100644 --- a/relay/server/listener/quic/conn.go +++ b/relay/server/listener/quic/conn.go @@ -3,33 +3,26 @@ package quic import ( "context" "errors" - "fmt" "net" "sync" - "time" "github.com/quic-go/quic-go" ) type Conn struct { - session *quic.Conn - closed bool - closedMu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc + session *quic.Conn + closed bool + closedMu sync.Mutex } func NewConn(session *quic.Conn) *Conn { - ctx, cancel := context.WithCancel(context.Background()) return &Conn{ - session: session, - ctx: ctx, - ctxCancel: cancel, + session: session, } } -func (c *Conn) Read(b []byte) (n int, err error) { - dgram, err := c.session.ReceiveDatagram(c.ctx) +func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(ctx) if err != nil { return 0, c.remoteCloseErrHandling(err) } @@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) { return n, nil } -func (c *Conn) Write(b []byte) (int, error) { +func (c *Conn) Write(_ context.Context, b []byte) (int, error) { if err := c.session.SendDatagram(b); err != nil { return 0, c.remoteCloseErrHandling(err) } return len(b), nil } -func (c *Conn) LocalAddr() net.Addr { - return c.session.LocalAddr() -} - func (c *Conn) RemoteAddr() net.Addr { return c.session.RemoteAddr() } -func (c *Conn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("SetWriteDeadline is not implemented") -} - -func (c *Conn) SetDeadline(t time.Time) error { - return fmt.Errorf("SetDeadline is not implemented") -} - func (c *Conn) Close() error { c.closedMu.Lock() if c.closed { @@ -74,8 +51,6 @@ func (c *Conn) Close() error { c.closed = true c.closedMu.Unlock() - c.ctxCancel() // Cancel the context - sessionErr := c.session.CloseWithError(0, "normal closure") return sessionErr } diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go index 797223e74..68f0e03c0 100644 --- a/relay/server/listener/quic/listener.go +++ b/relay/server/listener/quic/listener.go @@ -5,12 +5,12 @@ import ( "crypto/tls" "errors" "fmt" - "net" "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + relaylistener "github.com/netbirdio/netbird/relay/server/listener" nbRelay "github.com/netbirdio/netbird/shared/relay" ) @@ -25,7 +25,7 @@ type Listener struct { listener *quic.Listener } -func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { +func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error { quicCfg := &quic.Config{ EnableDatagrams: true, InitialPacketSize: nbRelay.QUICInitialPacketSize, diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index d5bce56f7..c22b5719d 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -18,25 +18,21 @@ const ( type Conn struct { *websocket.Conn - lAddr *net.TCPAddr rAddr *net.TCPAddr closed bool closedMu sync.Mutex - ctx context.Context } -func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { +func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn { return &Conn{ Conn: wsConn, - lAddr: lAddr, rAddr: rAddr, - ctx: context.Background(), } } -func (c *Conn) Read(b []byte) (n int, err error) { - t, r, err := c.Reader(c.ctx) +func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) { + t, r, err := c.Reader(ctx) if err != nil { return 0, c.ioErrHandling(err) } @@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Write writes a binary message with the given payload. // It does not block until fill the internal buffer. // If the buffer filled up, wait until the buffer is drained or timeout. -func (c *Conn) Write(b []byte) (int, error) { - ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) +func (c *Conn) Write(ctx context.Context, b []byte) (int, error) { + ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout) defer ctxCancel() err := c.Conn.Write(ctx, websocket.MessageBinary, b) return len(b), err } -func (c *Conn) LocalAddr() net.Addr { - return c.lAddr -} - func (c *Conn) RemoteAddr() net.Addr { return c.rAddr } -func (c *Conn) SetReadDeadline(t time.Time) error { - return fmt.Errorf("SetReadDeadline is not implemented") -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("SetWriteDeadline is not implemented") -} - -func (c *Conn) SetDeadline(t time.Time) error { - return fmt.Errorf("SetDeadline is not implemented") -} - func (c *Conn) Close() error { c.closedMu.Lock() c.closed = true diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 12219e29b..ba175f901 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -7,11 +7,13 @@ import ( "fmt" "net" "net/http" + "time" "github.com/coder/websocket" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + relaylistener "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/shared/relay" ) @@ -27,18 +29,19 @@ type Listener struct { TLSConfig *tls.Config server *http.Server - acceptFn func(conn net.Conn) + acceptFn func(conn relaylistener.Conn) } -func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { +func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error { l.acceptFn = acceptFn mux := http.NewServeMux() mux.HandleFunc(URLPath, l.onAccept) l.server = &http.Server{ - Addr: l.Address, - Handler: mux, - TLSConfig: l.TLSConfig, + Addr: l.Address, + Handler: mux, + TLSConfig: l.TLSConfig, + ReadHeaderTimeout: 5 * time.Second, } log.Infof("WS server listening address: %s", l.Address) @@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { return } - lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr) - if err != nil { - err = wsConn.Close(websocket.StatusInternalError, "internal error") - if err != nil { - log.Errorf("failed to close ws connection: %s", err) - } - return - } - log.Infof("WS client connected from: %s", rAddr) - conn := NewConn(wsConn, lAddr, rAddr) + conn := NewConn(wsConn, rAddr) l.acceptFn(conn) } diff --git a/relay/server/peer.go b/relay/server/peer.go index c5ff41857..8376cdfa7 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" @@ -26,11 +27,14 @@ type Peer struct { metrics *metrics.Metrics log *log.Entry id messages.PeerID - conn net.Conn + conn listener.Conn connMu sync.RWMutex store *store.Store notifier *store.PeerNotifier + ctx context.Context + ctxCancel context.CancelFunc + peersListener *store.Listener // between the online peer collection step and the notification sending should not be sent offline notifications from another thread @@ -38,14 +42,17 @@ type Peer struct { } // NewPeer creates a new Peer instance and prepare custom logging -func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { +func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { + ctx, cancel := context.WithCancel(context.Background()) p := &Peer{ - metrics: metrics, - log: log.WithField("peer_id", id.String()), - id: id, - conn: conn, - store: store, - notifier: notifier, + metrics: metrics, + log: log.WithField("peer_id", id.String()), + id: id, + conn: conn, + store: store, + notifier: notifier, + ctx: ctx, + ctxCancel: cancel, } return p @@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store func (p *Peer) Work() { p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) defer func() { + p.ctxCancel() p.notifier.RemoveListener(p.peersListener) if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { @@ -64,8 +72,7 @@ func (p *Peer) Work() { } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := p.ctx hc := healthcheck.NewSender(p.log) go hc.StartHealthCheck(ctx) @@ -73,7 +80,7 @@ func (p *Peer) Work() { buf := make([]byte, bufferSize) for { - n, err := p.conn.Read(buf) + n, err := p.conn.Read(ctx, buf) if err != nil { if !errors.Is(err, net.ErrClosed) { p.log.Errorf("failed to read message: %s", err) @@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * } // Write writes data to the connection -func (p *Peer) Write(b []byte) (int, error) { +func (p *Peer) Write(ctx context.Context, b []byte) (int, error) { p.connMu.RLock() defer p.connMu.RUnlock() - return p.conn.Write(b) + return p.conn.Write(ctx, b) } // CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the @@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } + p.ctxCancel() if err := p.conn.Close(); err != nil { p.log.Errorf(errCloseConn, err) } @@ -156,6 +164,7 @@ func (p *Peer) Close() { p.connMu.Lock() defer p.connMu.Unlock() + p.ctxCancel() if err := p.conn.Close(); err != nil { p.log.Errorf(errCloseConn, err) } @@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - writeDone := make(chan struct{}) - var err error - go func() { - _, err = p.conn.Write(buf) - close(writeDone) - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-writeDone: - return err - } + _, err := p.conn.Write(ctx, buf) + return err } func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) { for { select { case <-hc.HealthCheck: - _, err := p.Write(messages.MarshalHealthcheck()) + _, err := p.Write(ctx, messages.MarshalHealthcheck()) if err != nil { p.log.Errorf("failed to send healthcheck message: %s", err) return @@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - n, err := dp.Write(msg) + n, err := dp.Write(dp.ctx, msg) if err != nil { p.log.Errorf("failed to write transport message to: %s", dp.String()) return } - p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) + p.metrics.TransferBytesSent.Add(p.ctx, int64(n)) } func (p *Peer) handleSubscribePeerState(msg []byte) { @@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) { } for n, msg := range msgs { - if _, err := p.Write(msg); err != nil { + if _, err := p.Write(p.ctx, msg); err != nil { p.log.Errorf("failed to write %d. peers offline message: %s", n, err) } } @@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { } for n, msg := range msgs { - if _, err := p.Write(msg); err != nil { + if _, err := p.Write(p.ctx, msg); err != nil { p.log.Errorf("failed to write %d. peers offline message: %s", n, err) } } diff --git a/relay/server/relay.go b/relay/server/relay.go index bb355f58f..56add8bea 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "net" "net/url" "sync" "time" @@ -13,11 +12,20 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/healthcheck/peerid" + "github.com/netbirdio/netbird/relay/protocol" + "github.com/netbirdio/netbird/relay/server/listener" + //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" ) +type Listener interface { + Listen(func(conn listener.Conn)) error + Shutdown(ctx context.Context) error + Protocol() protocol.Protocol +} + type Config struct { Meter metric.Meter ExposedAddress string @@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) { } // Accept start to handle a new peer connection -func (r *Relay) Accept(conn net.Conn) { +func (r *Relay) Accept(conn listener.Conn) { acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() @@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) { return } + hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout) + defer hsCancel() + h := handshake{ conn: conn, validator: r.validator, preparedMsg: r.preparedMsg, } - peerID, err := h.handshakeReceive() + peerID, err := h.handshakeReceive(hsCtx) if err != nil { if peerid.IsHealthCheck(peerID) { log.Debugf("health check connection from %s", conn.RemoteAddr()) @@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) { r.metrics.PeerDisconnected(peer.String()) }() - if err := h.handshakeResponse(); err != nil { + if err := h.handshakeResponse(hsCtx); err != nil { log.Errorf("failed to send handshake response, close peer: %s", err) peer.Close() } diff --git a/relay/server/server.go b/relay/server/server.go index a0f7eb73c..340da55b8 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "crypto/tls" - "net" "net/url" "sync" @@ -31,7 +30,7 @@ type ListenerConfig struct { // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { relay *Relay - listeners []listener.Listener + listeners []Listener listenerMux sync.Mutex } @@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) { } return &Server{ relay: relay, - listeners: make([]listener.Listener, 0, 2), + listeners: make([]Listener, 0, 2), }, nil } @@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error { wg := sync.WaitGroup{} for _, l := range r.listeners { wg.Add(1) - go func(listener listener.Listener) { + go func(listener Listener) { defer wg.Done() errChan <- listener.Listen(r.relay.Accept) }(l) @@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL { // RelayAccept returns the relay's Accept function for handling incoming connections. // This allows external HTTP handlers to route connections to the relay without // starting the relay's own listeners. -func (r *Server) RelayAccept() func(conn net.Conn) { +func (r *Server) RelayAccept() func(conn listener.Conn) { return r.relay.Accept } diff --git a/release_files/install.sh b/release_files/install.sh index 6a2c5f458..1e71936f3 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -128,7 +128,7 @@ cat <<-EOF | ${SUDO} tee /etc/yum.repos.d/netbird.repo name=NetBird baseurl=https://pkgs.netbird.io/yum/ enabled=1 -gpgcheck=0 +gpgcheck=1 gpgkey=https://pkgs.netbird.io/yum/repodata/repomd.xml.key repo_gpgcheck=1 EOF diff --git a/shared/auth/jwt/validator.go b/shared/auth/jwt/validator.go index aeaa5842c..cf18b2cf6 100644 --- a/shared/auth/jwt/validator.go +++ b/shared/auth/jwt/validator.go @@ -25,7 +25,7 @@ import ( // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKey `json:"keys"` - expiresInTime time.Time + ExpiresInTime time.Time `json:"-"` } // The supported elliptic curves types @@ -53,12 +53,17 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } +// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage) +// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys. +type KeyFetcher func(ctx context.Context) (*Jwks, error) + type Validator struct { lock sync.Mutex issuer string audienceList []string keysLocation string idpSignkeyRefreshEnabled bool + keyFetcher KeyFetcher keys *Jwks lastForcedRefresh time.Time } @@ -85,10 +90,39 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp } } +// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the +// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP. +func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator { + ctx := context.Background() + keys, err := keyFetcher(ctx) + if err != nil { + log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err) + } + if keys == nil { + keys = &Jwks{} + } + + return &Validator{ + keys: keys, + issuer: issuer, + audienceList: audienceList, + idpSignkeyRefreshEnabled: true, + keyFetcher: keyFetcher, + } +} + // forcedRefreshCooldown is the minimum time between forced key refreshes // to prevent abuse from invalid tokens with fake kid values const forcedRefreshCooldown = 30 * time.Second +// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP. +func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) { + if v.keyFetcher != nil { + return v.keyFetcher(ctx) + } + return getPemKeys(v.keysLocation) +} + func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { // If keys are rotated, verify the keys prior to token validation @@ -131,13 +165,13 @@ func (v *Validator) refreshKeys(ctx context.Context) { v.lock.Lock() defer v.lock.Unlock() - refreshedKeys, err := getPemKeys(v.keysLocation) + refreshedKeys, err := v.fetchKeys(ctx) if err != nil { log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) return } - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC()) v.keys = refreshedKeys } @@ -155,13 +189,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool { log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh") - refreshedKeys, err := getPemKeys(v.keysLocation) + refreshedKeys, err := v.fetchKeys(ctx) if err != nil { log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) return false } - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC()) v.keys = refreshedKeys v.lastForcedRefresh = time.Now() return true @@ -203,7 +237,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To // stillValid returns true if the JSONWebKey still valid and have enough time to be used func (jwks *Jwks) stillValid() bool { - return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) + return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime) } func getPemKeys(keysLocation string) (*Jwks, error) { @@ -227,7 +261,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) { cacheControlHeader := resp.Header.Get("Cache-Control") expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) - jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) + jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) return jwks, nil } diff --git a/shared/management/client/client.go b/shared/management/client/client.go index b92c636c5..18efba87b 100644 --- a/shared/management/client/client.go +++ b/shared/management/client/client.go @@ -4,24 +4,31 @@ import ( "context" "io" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" ) +// Client is the interface for the management service client. type Client interface { io.Closer Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error - GetServerPublicKey() (*wgtypes.Key, error) - Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) + Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) + GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) + GetServerURL() string + // IsHealthy returns the current connection status without blocking. + // Used by the engine to monitor connectivity in the background. IsHealthy() bool + // HealthCheck actively probes the management server and returns an error if unreachable. + // Used to validate connectivity before committing configuration changes. + HealthCheck() error SyncMeta(sysInfo *system.Info) error Logout() error + CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) + RenewExpose(ctx context.Context, domain string) error + StopExpose(ctx context.Context, domain string) error } diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index a11f863a7..d9a1a7d65 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" @@ -95,9 +96,16 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersManger) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -116,11 +124,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } @@ -189,7 +196,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) { } } -func TestClient_GetServerPublicKey(t *testing.T) { +func TestClient_HealthCheck(t *testing.T) { testKey, err := wgtypes.GenerateKey() if err != nil { t.Fatal(err) @@ -203,12 +210,8 @@ func TestClient_GetServerPublicKey(t *testing.T) { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Error("couldn't retrieve management public key") - } - if key == nil { - t.Error("got an empty management public key") + if err := client.HealthCheck(); err != nil { + t.Errorf("health check failed: %v", err) } } @@ -225,12 +228,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) { if err != nil { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Fatal(err) - } sysInfo := system.GetInfo(context.TODO()) - _, err = client.Login(*key, sysInfo, nil, nil) + _, err = client.Login(sysInfo, nil, nil) if err == nil { t.Error("expecting err on unregistered login, got nil") } @@ -253,12 +252,8 @@ func TestClient_LoginRegistered(t *testing.T) { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Error(err) - } info := system.GetInfo(context.TODO()) - resp, err := client.Register(*key, ValidKey, "", info, nil, nil) + resp, err := client.Register(ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -282,13 +277,8 @@ func TestClient_Sync(t *testing.T) { t.Fatal(err) } - serverKey, err := client.GetServerPublicKey() - if err != nil { - t.Error(err) - } - info := system.GetInfo(context.TODO()) - _, err = client.Register(*serverKey, ValidKey, "", info, nil, nil) + _, err = client.Register(ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -304,7 +294,7 @@ func TestClient_Sync(t *testing.T) { } info = system.GetInfo(context.TODO()) - _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil) + _, err = remoteClient.Register(ValidKey, "", info, nil, nil) if err != nil { t.Fatal(err) } @@ -364,11 +354,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) { t.Fatalf("error while creating testClient: %v", err) } - key, err := testClient.GetServerPublicKey() - if err != nil { - t.Fatalf("error while getting server public key from testclient, %v", err) - } - var actualMeta *mgmtProto.PeerSystemMeta var actualValidKey string var wg sync.WaitGroup @@ -405,7 +390,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = testClient.Register(*key, ValidKey, "", info, nil, nil) + _, err = testClient.Register(ValidKey, "", info, nil, nil) if err != nil { t.Errorf("error while trying to register client: %v", err) } @@ -505,7 +490,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { } mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { - encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) + encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo) if err != nil { return nil, err } @@ -517,7 +502,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { }, nil } - flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey) + flowInfo, err := client.GetDeviceAuthorizationFlow() if err != nil { t.Error("error while retrieving device auth flow information") } @@ -551,7 +536,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { } mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { - encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) + encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo) if err != nil { return nil, err } @@ -563,11 +548,11 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { }, nil } - flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey) + flowInfo, err := client.GetPKCEAuthorizationFlow() if err != nil { t.Error("error while retrieving pkce auth flow information") } assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") - assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") + assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") //nolint:staticcheck } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index d54c8f870..a01e51abc 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "io" + "os" + "strconv" "sync" "time" @@ -29,6 +31,10 @@ import ( const ConnectTimeout = 10 * time.Second const ( + // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) + // for the management client connection. Value is in bytes. + EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE" + errMsgMgmtPublicKey = "failed getting Management Service public key: %s" errMsgNoMgmtConnection = "no connection to management" ) @@ -46,15 +52,62 @@ type GrpcClient struct { conn *grpc.ClientConn connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + serverURL string +} + +type ExposeRequest struct { + NamePrefix string + Domain string + Port uint16 + Protocol int + Pin string + Password string + UserGroups []string + ListenPort uint16 +} + +type ExposeResponse struct { + ServiceName string + Domain string + ServiceURL string + PortAutoAssigned bool +} + +// MaxRecvMsgSize returns the configured max gRPC receive message size from +// the environment, or 0 if unset (which uses the gRPC default of 4 MB). +func MaxRecvMsgSize() int { + val := os.Getenv(EnvMaxRecvMsgSize) + if val == "" { + return 0 + } + + size, err := strconv.Atoi(val) + if err != nil { + log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err) + return 0 + } + + if size <= 0 { + log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size) + return 0 + } + + return size } // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { var conn *grpc.ClientConn + var extraOpts []grpc.DialOption + if maxSize := MaxRecvMsgSize(); maxSize > 0 { + extraOpts = append(extraOpts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxSize))) + log.Infof("management gRPC max receive message size set to %d bytes", maxSize) + } + operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent, extraOpts...) if err != nil { return fmt.Errorf("create connection: %w", err) } @@ -75,9 +128,15 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE ctx: ctx, conn: conn, connStateCallbackLock: sync.RWMutex{}, + serverURL: addr, }, nil } +// GetServerURL returns the management server URL +func (c *GrpcClient) GetServerURL() string { + return c.serverURL +} + // Close closes connection to the Management Service func (c *GrpcClient) Close() error { return c.conn.Close() @@ -143,7 +202,7 @@ func (c *GrpcClient) withMgmtStream( return fmt.Errorf("connection to management is not ready and in %s state", connState) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err @@ -345,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes. // GetNetworkMap return with the network map func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) { - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { log.Debugf("failed getting Management Service public key: %s", err) return nil, err @@ -431,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli } } -// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) -func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { +// HealthCheck actively probes the management server and returns an error if unreachable. +// Used to validate connectivity before committing configuration changes. +func (c *GrpcClient) HealthCheck() error { if !c.ready() { - return nil, errors.New(errMsgNoMgmtConnection) + return errors.New(errMsgNoMgmtConnection) } + _, err := c.getServerPublicKey() + return err +} + +// getServerPublicKey fetches the server's WireGuard public key. +func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) { mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) defer cancel() resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, fmt.Errorf("failed while getting Management Service public key") + return nil, fmt.Errorf("failed getting Management Service public key: %w", err) } serverKey, err := wgtypes.ParseKey(resp.Key) @@ -453,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { return &serverKey, nil } -// IsHealthy probes the gRPC connection and returns false on errors +// IsHealthy returns the current connection status without blocking. +// Used by the engine to monitor connectivity in the background. func (c *GrpcClient) IsHealthy() bool { switch c.conn.GetState() { case connectivity.TransientFailure: @@ -479,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool { return true } -func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { +func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) { if !c.ready() { return nil, errors.New(errMsgNoMgmtConnection) } - loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) + serverKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + + loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err @@ -518,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro } loginResp := &proto.LoginResponse{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp) if err != nil { log.Errorf("failed to decrypt login response: %s", err) return nil, err @@ -530,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Takes care of encrypting and decrypting messages. // This method will also collect system info and send it with the request (e.g. hostname, os, etc) -func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. -func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // GetDeviceAuthorizationFlow returns a device authorization flow information. // It also takes care of encrypting and decrypting messages. -func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { +func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management in order to get device authorization flow") } + + serverKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) defer cancel() message := &proto.DeviceAuthorizationFlowRequest{} - encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) if err != nil { return nil, err } @@ -571,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D } flowInfoResp := &proto.DeviceAuthorizationFlow{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp) if err != nil { errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err) log.Error(errWithMSG) @@ -583,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D // GetPKCEAuthorizationFlow returns a pkce authorization flow information. // It also takes care of encrypting and decrypting messages. -func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { +func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow") } + + serverKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) defer cancel() message := &proto.PKCEAuthorizationFlowRequest{} - encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) if err != nil { return nil, err } @@ -605,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC } flowInfoResp := &proto.PKCEAuthorizationFlow{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp) if err != nil { errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err) log.Error(errWithMSG) @@ -622,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error { return errors.New(errMsgNoMgmtConnection) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey() if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err @@ -665,7 +748,7 @@ func (c *GrpcClient) notifyConnected() { } func (c *GrpcClient) Logout() error { - serverKey, err := c.GetServerPublicKey() + serverKey, err := c.getServerPublicKey() if err != nil { return fmt.Errorf("get server public key: %w", err) } @@ -690,6 +773,127 @@ func (c *GrpcClient) Logout() error { return nil } +// CreateExpose calls the management server to create a new expose service. +func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) { + serverPubKey, err := c.getServerPublicKey() + if err != nil { + return nil, err + } + + protoReq, err := toProtoExposeServiceRequest(req) + if err != nil { + return nil, err + } + + encReq, err := encryption.EncryptMessage(*serverPubKey, c.key, protoReq) + if err != nil { + return nil, fmt.Errorf("encrypt create expose request: %w", err) + } + + mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) + defer cancel() + + resp, err := c.realClient.CreateExpose(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: encReq, + }) + if err != nil { + return nil, err + } + + exposeResp := &proto.ExposeServiceResponse{} + if err := encryption.DecryptMessage(*serverPubKey, c.key, resp.Body, exposeResp); err != nil { + return nil, fmt.Errorf("decrypt create expose response: %w", err) + } + + return fromProtoExposeResponse(exposeResp), nil +} + +// RenewExpose extends the TTL of an active expose session on the management server. +func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error { + serverPubKey, err := c.getServerPublicKey() + if err != nil { + return err + } + + req := &proto.RenewExposeRequest{Domain: domain} + encReq, err := encryption.EncryptMessage(*serverPubKey, c.key, req) + if err != nil { + return fmt.Errorf("encrypt renew expose request: %w", err) + } + + mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) + defer cancel() + + _, err = c.realClient.RenewExpose(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: encReq, + }) + return err +} + +// StopExpose terminates an active expose session on the management server. +func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error { + serverPubKey, err := c.getServerPublicKey() + if err != nil { + return err + } + + req := &proto.StopExposeRequest{Domain: domain} + encReq, err := encryption.EncryptMessage(*serverPubKey, c.key, req) + if err != nil { + return fmt.Errorf("encrypt stop expose request: %w", err) + } + + mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) + defer cancel() + + _, err = c.realClient.StopExpose(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: encReq, + }) + return err +} + +func fromProtoExposeResponse(resp *proto.ExposeServiceResponse) *ExposeResponse { + return &ExposeResponse{ + ServiceName: resp.ServiceName, + Domain: resp.Domain, + ServiceURL: resp.ServiceUrl, + PortAutoAssigned: resp.PortAutoAssigned, + } +} + +func toProtoExposeServiceRequest(req ExposeRequest) (*proto.ExposeServiceRequest, error) { + var protocol proto.ExposeProtocol + + switch req.Protocol { + case int(proto.ExposeProtocol_EXPOSE_HTTP): + protocol = proto.ExposeProtocol_EXPOSE_HTTP + case int(proto.ExposeProtocol_EXPOSE_HTTPS): + protocol = proto.ExposeProtocol_EXPOSE_HTTPS + case int(proto.ExposeProtocol_EXPOSE_TCP): + protocol = proto.ExposeProtocol_EXPOSE_TCP + case int(proto.ExposeProtocol_EXPOSE_UDP): + protocol = proto.ExposeProtocol_EXPOSE_UDP + case int(proto.ExposeProtocol_EXPOSE_TLS): + protocol = proto.ExposeProtocol_EXPOSE_TLS + default: + return nil, fmt.Errorf("invalid expose protocol: %d", req.Protocol) + } + + return &proto.ExposeServiceRequest{ + NamePrefix: req.NamePrefix, + Domain: req.Domain, + Port: uint32(req.Port), + Protocol: protocol, + Pin: req.Pin, + Password: req.Password, + UserGroups: req.UserGroups, + ListenPort: uint32(req.ListenPort), + }, nil +} + func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { if info == nil { return nil diff --git a/shared/management/client/grpc_test.go b/shared/management/client/grpc_test.go new file mode 100644 index 000000000..462cc43af --- /dev/null +++ b/shared/management/client/grpc_test.go @@ -0,0 +1,95 @@ +package client + +import ( + "context" + "net" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestMaxRecvMsgSize(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {name: "unset returns 0", envValue: "", expected: 0}, + {name: "valid value", envValue: "10485760", expected: 10485760}, + {name: "non-numeric returns 0", envValue: "abc", expected: 0}, + {name: "negative returns 0", envValue: "-1", expected: 0}, + {name: "zero returns 0", envValue: "0", expected: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(EnvMaxRecvMsgSize, tt.envValue) + if tt.envValue == "" { + os.Unsetenv(EnvMaxRecvMsgSize) + } + assert.Equal(t, tt.expected, MaxRecvMsgSize()) + }) + } +} + +// largeSyncServer implements just the Sync RPC, returning a response larger than the default 4MB limit. +type largeSyncServer struct { + mgmtProto.UnimplementedManagementServiceServer + responseSize int +} + +func (s *largeSyncServer) GetServerKey(_ context.Context, _ *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) { + // Return a response with a large WiretrusteeConfig to exceed the default limit. + padding := strings.Repeat("x", s.responseSize) + return &mgmtProto.ServerKeyResponse{ + Key: padding, + }, nil +} + +func TestMaxRecvMsgSizeIntegration(t *testing.T) { + const payloadSize = 5 * 1024 * 1024 // 5MB, exceeds 4MB default + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + srv := grpc.NewServer() + mgmtProto.RegisterManagementServiceServer(srv, &largeSyncServer{responseSize: payloadSize}) + go func() { _ = srv.Serve(lis) }() + t.Cleanup(srv.Stop) + + t.Run("default limit rejects large message", func(t *testing.T) { + conn, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer conn.Close() + + client := mgmtProto.NewManagementServiceClient(conn) + _, err = client.GetServerKey(context.Background(), &mgmtProto.Empty{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "received message larger than max") + }) + + t.Run("increased limit accepts large message", func(t *testing.T) { + conn, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(10*1024*1024)), + ) + require.NoError(t, err) + defer conn.Close() + + client := mgmtProto.NewManagementServiceClient(conn) + resp, err := client.GetServerKey(context.Background(), &mgmtProto.Empty{}) + require.NoError(t, err) + assert.Len(t, resp.Key, payloadSize) + }) +} diff --git a/shared/management/client/mock.go b/shared/management/client/mock.go index ac96f7b36..361e8ffad 100644 --- a/shared/management/client/mock.go +++ b/shared/management/client/mock.go @@ -3,24 +3,27 @@ package client import ( "context" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" ) +// MockClient is a mock implementation of the Client interface for testing. type MockClient struct { CloseFunc func() error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error - GetServerPublicKeyFunc func() (*wgtypes.Key, error) - RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) + RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error) + GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error) + GetServerURLFunc func() string + HealthCheckFunc func() error SyncMetaFunc func(sysInfo *system.Info) error LogoutFunc func() error JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error + CreateExposeFunc func(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) + RenewExposeFunc func(ctx context.Context, domain string) error + StopExposeFunc func(ctx context.Context, domain string) error } func (m *MockClient) IsHealthy() bool { @@ -48,46 +51,54 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ return m.JobFunc(ctx, msgHandler) } -func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { - if m.GetServerPublicKeyFunc == nil { - return nil, nil - } - return m.GetServerPublicKeyFunc() -} - -func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.RegisterFunc == nil { return nil, nil } - return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels) + return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels) } -func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.LoginFunc == nil { return nil, nil } - return m.LoginFunc(serverKey, info, sshKey, dnsLabels) + return m.LoginFunc(info, sshKey, dnsLabels) } -func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { +func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) { if m.GetDeviceAuthorizationFlowFunc == nil { return nil, nil } - return m.GetDeviceAuthorizationFlowFunc(serverKey) + return m.GetDeviceAuthorizationFlowFunc() } -func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { +func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) { if m.GetPKCEAuthorizationFlowFunc == nil { return nil, nil } - return m.GetPKCEAuthorizationFlow(serverKey) + return m.GetPKCEAuthorizationFlowFunc() } -// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface +func (m *MockClient) HealthCheck() error { + if m.HealthCheckFunc == nil { + return nil + } + return m.HealthCheckFunc() +} + +// GetNetworkMap mock implementation of GetNetworkMap from Client interface. func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) { return nil, nil } +// GetServerURL mock implementation of GetServerURL from mgm.Client interface +func (m *MockClient) GetServerURL() string { + if m.GetServerURLFunc == nil { + return "" + } + return m.GetServerURLFunc() +} + func (m *MockClient) SyncMeta(sysInfo *system.Info) error { if m.SyncMetaFunc == nil { return nil @@ -101,3 +112,24 @@ func (m *MockClient) Logout() error { } return m.LogoutFunc() } + +func (m *MockClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) { + if m.CreateExposeFunc == nil { + return nil, nil + } + return m.CreateExposeFunc(ctx, req) +} + +func (m *MockClient) RenewExpose(ctx context.Context, domain string) error { + if m.RenewExposeFunc == nil { + return nil + } + return m.RenewExposeFunc(ctx, domain) +} + +func (m *MockClient) StopExpose(ctx context.Context, domain string) error { + if m.StopExposeFunc == nil { + return nil + } + return m.StopExposeFunc(ctx, domain) +} diff --git a/shared/management/client/rest/azure_idp.go b/shared/management/client/rest/azure_idp.go new file mode 100644 index 000000000..40b90bc30 --- /dev/null +++ b/shared/management/client/rest/azure_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// AzureIDPAPI APIs for Azure AD IDP integrations +type AzureIDPAPI struct { + c *Client +} + +// List retrieves all Azure AD IDP integrations +func (a *AzureIDPAPI) List(ctx context.Context) ([]api.AzureIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.AzureIntegration](resp) + return ret, err +} + +// Get retrieves a specific Azure AD IDP integration by ID +func (a *AzureIDPAPI) Get(ctx context.Context, integrationID string) (*api.AzureIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Create creates a new Azure AD IDP integration +func (a *AzureIDPAPI) Create(ctx context.Context, request api.CreateAzureIntegrationRequest) (*api.AzureIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Update updates an existing Azure AD IDP integration +func (a *AzureIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateAzureIntegrationRequest) (*api.AzureIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/azure-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Delete deletes an Azure AD IDP integration +func (a *AzureIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/azure-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Sync triggers a manual sync for an Azure AD IDP integration +func (a *AzureIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp/"+integrationID+"/sync", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.SyncResult](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for an Azure AD IDP integration +func (a *AzureIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/azure_idp_test.go b/shared/management/client/rest/azure_idp_test.go new file mode 100644 index 000000000..480d2a313 --- /dev/null +++ b/shared/management/client/rest/azure_idp_test.go @@ -0,0 +1,252 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testAzureIntegration = api.AzureIntegration{ + Id: 1, + Enabled: true, + ClientId: "12345678-1234-1234-1234-123456789012", + TenantId: "87654321-4321-4321-4321-210987654321", + SyncInterval: 300, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + Host: "microsoft.com", + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestAzureIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.AzureIntegration{testAzureIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testAzureIntegration, ret[0]) + }) +} + +func TestAzureIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestAzureIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateAzureIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "12345678-1234-1234-1234-123456789012", req.ClientId) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{ + ClientId: "12345678-1234-1234-1234-123456789012", + ClientSecret: "secret", + TenantId: "87654321-4321-4321-4321-210987654321", + Host: api.CreateAzureIntegrationRequestHostMicrosoftCom, + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateAzureIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.AzureIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestAzureIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.AzureIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestAzureIDP_Sync_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Sync(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, "ok", *ret.Result) + }) +} + +func TestAzureIDP_Sync_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Sync(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestAzureIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index 99d8eb594..f0cb4d2d1 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -11,6 +11,26 @@ import ( "github.com/netbirdio/netbird/shared/management/http/util" ) +// APIError represents an error response from the management API. +type APIError struct { + StatusCode int + Message string +} + +// Error implements the error interface. +func (e *APIError) Error() string { + return e.Message +} + +// IsNotFound returns true if the error represents a 404 Not Found response. +func IsNotFound(err error) bool { + var apiErr *APIError + if ok := errors.As(err, &apiErr); ok { + return apiErr.StatusCode == http.StatusNotFound + } + return false +} + // Client Management service HTTP REST API Client type Client struct { managementURL string @@ -90,6 +110,15 @@ type Client struct { // see more: https://docs.netbird.io/api/resources/scim SCIM *SCIMAPI + // GoogleIDP NetBird Google Workspace IDP integration APIs + GoogleIDP *GoogleIDPAPI + + // AzureIDP NetBird Azure AD IDP integration APIs + AzureIDP *AzureIDPAPI + + // OktaScimIDP NetBird Okta SCIM IDP integration APIs + OktaScimIDP *OktaScimIDPAPI + // EventStreaming NetBird Event Streaming integration APIs // see more: https://docs.netbird.io/api/resources/event-streaming EventStreaming *EventStreamingAPI @@ -105,6 +134,15 @@ type Client struct { // Instance NetBird Instance API // see more: https://docs.netbird.io/api/resources/instance Instance *InstanceAPI + + // ReverseProxyServices NetBird reverse proxy services APIs + ReverseProxyServices *ReverseProxyServicesAPI + + // ReverseProxyClusters NetBird reverse proxy clusters APIs + ReverseProxyClusters *ReverseProxyClustersAPI + + // ReverseProxyDomains NetBird reverse proxy domains APIs + ReverseProxyDomains *ReverseProxyDomainsAPI } // New initialize new Client instance using PAT token @@ -156,10 +194,16 @@ func (c *Client) initialize() { c.MSP = &MSPAPI{c} c.EDR = &EDRAPI{c} c.SCIM = &SCIMAPI{c} + c.GoogleIDP = &GoogleIDPAPI{c} + c.AzureIDP = &AzureIDPAPI{c} + c.OktaScimIDP = &OktaScimIDPAPI{c} c.EventStreaming = &EventStreamingAPI{c} c.IdentityProviders = &IdentityProvidersAPI{c} c.Ingress = &IngressAPI{c} c.Instance = &InstanceAPI{c} + c.ReverseProxyServices = &ReverseProxyServicesAPI{c} + c.ReverseProxyClusters = &ReverseProxyClustersAPI{c} + c.ReverseProxyDomains = &ReverseProxyDomainsAPI{c} } // NewRequest creates and executes new management API request @@ -194,10 +238,12 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re if resp.StatusCode > 299 { parsedErr, pErr := parseResponse[util.ErrorResponse](resp) if pErr != nil { - return nil, pErr } - return nil, errors.New(parsedErr.Message) + return nil, &APIError{ + StatusCode: resp.StatusCode, + Message: parsedErr.Message, + } } return resp, nil diff --git a/shared/management/client/rest/edr.go b/shared/management/client/rest/edr.go index 7dfc891c2..f9b7f2a88 100644 --- a/shared/management/client/rest/edr.go +++ b/shared/management/client/rest/edr.go @@ -265,6 +265,65 @@ func (a *EDRAPI) DeleteHuntressIntegration(ctx context.Context) error { return nil } +// GetFleetDMIntegration retrieves the EDR FleetDM integration. +func (a *EDRAPI) GetFleetDMIntegration(ctx context.Context) (*api.EDRFleetDMResponse, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/fleetdm", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.EDRFleetDMResponse](resp) + return &ret, err +} + +// CreateFleetDMIntegration creates a new EDR FleetDM integration. +func (a *EDRAPI) CreateFleetDMIntegration(ctx context.Context, request api.EDRFleetDMRequest) (*api.EDRFleetDMResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/fleetdm", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.EDRFleetDMResponse](resp) + return &ret, err +} + +// UpdateFleetDMIntegration updates an existing EDR FleetDM integration. +func (a *EDRAPI) UpdateFleetDMIntegration(ctx context.Context, request api.EDRFleetDMRequest) (*api.EDRFleetDMResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/fleetdm", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.EDRFleetDMResponse](resp) + return &ret, err +} + +// DeleteFleetDMIntegration deletes the EDR FleetDM integration. +func (a *EDRAPI) DeleteFleetDMIntegration(ctx context.Context) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/fleetdm", nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + // BypassPeerCompliance bypasses compliance for a non-compliant peer // See more: https://docs.netbird.io/api/resources/edr#bypass-peer-compliance func (a *EDRAPI) BypassPeerCompliance(ctx context.Context, peerID string) (*api.BypassResponse, error) { diff --git a/shared/management/client/rest/google_idp.go b/shared/management/client/rest/google_idp.go new file mode 100644 index 000000000..b86436503 --- /dev/null +++ b/shared/management/client/rest/google_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// GoogleIDPAPI APIs for Google Workspace IDP integrations +type GoogleIDPAPI struct { + c *Client +} + +// List retrieves all Google Workspace IDP integrations +func (a *GoogleIDPAPI) List(ctx context.Context) ([]api.GoogleIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.GoogleIntegration](resp) + return ret, err +} + +// Get retrieves a specific Google Workspace IDP integration by ID +func (a *GoogleIDPAPI) Get(ctx context.Context, integrationID string) (*api.GoogleIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Create creates a new Google Workspace IDP integration +func (a *GoogleIDPAPI) Create(ctx context.Context, request api.CreateGoogleIntegrationRequest) (*api.GoogleIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Update updates an existing Google Workspace IDP integration +func (a *GoogleIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateGoogleIntegrationRequest) (*api.GoogleIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/google-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Delete deletes a Google Workspace IDP integration +func (a *GoogleIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/google-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Sync triggers a manual sync for a Google Workspace IDP integration +func (a *GoogleIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp/"+integrationID+"/sync", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.SyncResult](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for a Google Workspace IDP integration +func (a *GoogleIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/google_idp_test.go b/shared/management/client/rest/google_idp_test.go new file mode 100644 index 000000000..03a6c161e --- /dev/null +++ b/shared/management/client/rest/google_idp_test.go @@ -0,0 +1,248 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testGoogleIntegration = api.GoogleIntegration{ + Id: 1, + Enabled: true, + CustomerId: "C01234567", + SyncInterval: 300, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestGoogleIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.GoogleIntegration{testGoogleIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testGoogleIntegration, ret[0]) + }) +} + +func TestGoogleIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestGoogleIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateGoogleIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "C01234567", req.CustomerId) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{ + CustomerId: "C01234567", + ServiceAccountKey: "key-data", + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateGoogleIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.GoogleIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestGoogleIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.GoogleIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestGoogleIDP_Sync_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Sync(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, "ok", *ret.Result) + }) +} + +func TestGoogleIDP_Sync_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Sync(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestGoogleIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/okta_scim_idp.go b/shared/management/client/rest/okta_scim_idp.go new file mode 100644 index 000000000..eb677dae8 --- /dev/null +++ b/shared/management/client/rest/okta_scim_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// OktaScimIDPAPI APIs for Okta SCIM IDP integrations +type OktaScimIDPAPI struct { + c *Client +} + +// List retrieves all Okta SCIM IDP integrations +func (a *OktaScimIDPAPI) List(ctx context.Context) ([]api.OktaScimIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.OktaScimIntegration](resp) + return ret, err +} + +// Get retrieves a specific Okta SCIM IDP integration by ID +func (a *OktaScimIDPAPI) Get(ctx context.Context, integrationID string) (*api.OktaScimIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Create creates a new Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Create(ctx context.Context, request api.CreateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Update updates an existing Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/okta-scim-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Delete deletes an Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// RegenerateToken regenerates the SCIM API token for an Okta SCIM integration +func (a *OktaScimIDPAPI) RegenerateToken(ctx context.Context, integrationID string) (*api.ScimTokenResponse, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp/"+integrationID+"/token", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.ScimTokenResponse](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for an Okta SCIM IDP integration +func (a *OktaScimIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/okta_scim_idp_test.go b/shared/management/client/rest/okta_scim_idp_test.go new file mode 100644 index 000000000..d8d1f2b51 --- /dev/null +++ b/shared/management/client/rest/okta_scim_idp_test.go @@ -0,0 +1,246 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testOktaScimIntegration = api.OktaScimIntegration{ + Id: 1, + AuthToken: "****", + Enabled: true, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestOktaScimIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.OktaScimIntegration{testOktaScimIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testOktaScimIntegration, ret[0]) + }) +} + +func TestOktaScimIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestOktaScimIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateOktaScimIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "my-okta-connection", req.ConnectionName) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{ + ConnectionName: "my-okta-connection", + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateOktaScimIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.OktaScimIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestOktaScimIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.OktaScimIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestOktaScimIDP_RegenerateToken_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(testScimToken) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testScimToken, *ret) + }) +} + +func TestOktaScimIDP_RegenerateToken_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestOktaScimIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/reverse_proxy_clusters.go b/shared/management/client/rest/reverse_proxy_clusters.go new file mode 100644 index 000000000..b55cd35a3 --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_clusters.go @@ -0,0 +1,25 @@ +package rest + +import ( + "context" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// ReverseProxyClustersAPI APIs for Reverse Proxy Clusters, do not use directly +type ReverseProxyClustersAPI struct { + c *Client +} + +// List lists all available proxy clusters +func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/clusters", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.ProxyCluster](resp) + return ret, err +} diff --git a/shared/management/client/rest/reverse_proxy_domains.go b/shared/management/client/rest/reverse_proxy_domains.go new file mode 100644 index 000000000..7066a0632 --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_domains.go @@ -0,0 +1,72 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + "net/url" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// ReverseProxyDomainsAPI APIs for Reverse Proxy Domains, do not use directly +type ReverseProxyDomainsAPI struct { + c *Client +} + +// List lists all reverse proxy domains +func (a *ReverseProxyDomainsAPI) List(ctx context.Context) ([]api.ReverseProxyDomain, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/domains", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.ReverseProxyDomain](resp) + return ret, err +} + +// Create creates a new custom domain +func (a *ReverseProxyDomainsAPI) Create(ctx context.Context, request api.PostApiReverseProxiesDomainsJSONRequestBody) (*api.ReverseProxyDomain, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/domains", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.ReverseProxyDomain](resp) + if err != nil { + return nil, err + } + return &ret, nil +} + +// Delete deletes a custom domain +func (a *ReverseProxyDomainsAPI) Delete(ctx context.Context, domainID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/domains/"+url.PathEscape(domainID), nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Validate triggers domain ownership validation for a custom domain +func (a *ReverseProxyDomainsAPI) Validate(ctx context.Context, domainID string) error { + resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/domains/"+url.PathEscape(domainID)+"/validate", nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} diff --git a/shared/management/client/rest/reverse_proxy_services.go b/shared/management/client/rest/reverse_proxy_services.go new file mode 100644 index 000000000..2ecb382b2 --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_services.go @@ -0,0 +1,97 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + "net/url" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// ReverseProxyServicesAPI APIs for Reverse Proxy Services, do not use directly +type ReverseProxyServicesAPI struct { + c *Client +} + +// List lists all reverse proxy services +func (a *ReverseProxyServicesAPI) List(ctx context.Context) ([]api.Service, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/services", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.Service](resp) + return ret, err +} + +// Get retrieves a reverse proxy service by ID +func (a *ReverseProxyServicesAPI) Get(ctx context.Context, serviceID string) (*api.Service, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.Service](resp) + if err != nil { + return nil, err + } + return &ret, nil +} + +// Create creates a new reverse proxy service +func (a *ReverseProxyServicesAPI) Create(ctx context.Context, request api.PostApiReverseProxiesServicesJSONRequestBody) (*api.Service, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/services", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.Service](resp) + if err != nil { + return nil, err + } + return &ret, nil +} + +// Update updates a reverse proxy service +func (a *ReverseProxyServicesAPI) Update(ctx context.Context, serviceID string, request api.PutApiReverseProxiesServicesServiceIdJSONRequestBody) (*api.Service, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.Service](resp) + if err != nil { + return nil, err + } + return &ret, nil +} + +// Delete deletes a reverse proxy service +func (a *ReverseProxyServicesAPI) Delete(ctx context.Context, serviceID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + + return nil +} diff --git a/shared/management/client/rest/reverse_proxy_services_test.go b/shared/management/client/rest/reverse_proxy_services_test.go new file mode 100644 index 000000000..164563e97 --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_services_test.go @@ -0,0 +1,271 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testServiceTarget = api.ServiceTarget{ + TargetId: "peer-123", + TargetType: "peer", + Protocol: "https", + Port: 8443, + Enabled: true, +} + +var testService = api.Service{ + Id: "svc-1", + Name: "test-service", + Domain: "test.example.com", + Enabled: true, + Auth: api.ServiceAuthConfig{}, + Meta: api.ServiceMeta{ + CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), + Status: "active", + }, + Targets: []api.ServiceTarget{testServiceTarget}, +} + +func TestReverseProxyServices_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal([]api.Service{testService}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.List(context.Background()) + require.NoError(t, err) + require.Len(t, ret, 1) + assert.Equal(t, testService.Id, ret[0].Id) + assert.Equal(t, testService.Name, ret[0].Name) + }) +} + +func TestReverseProxyServices_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestReverseProxyServices_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(testService) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1") + require.NoError(t, err) + assert.Equal(t, testService.Id, ret.Id) + assert.Equal(t, testService.Domain, ret.Domain) + }) +} + +func TestReverseProxyServices_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1") + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestReverseProxyServices_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.ServiceRequest + require.NoError(t, json.Unmarshal(reqBytes, &req)) + assert.Equal(t, "test-service", req.Name) + assert.Equal(t, "test.example.com", req.Domain) + retBytes, _ := json.Marshal(testService) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{ + Name: "test-service", + Domain: "test.example.com", + Enabled: true, + Auth: api.ServiceAuthConfig{}, + Targets: []api.ServiceTarget{testServiceTarget}, + }) + require.NoError(t, err) + assert.Equal(t, testService.Id, ret.Id) + }) +} + +func TestReverseProxyServices_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{ + Name: "test-service", + Domain: "test.example.com", + Enabled: true, + Auth: api.ServiceAuthConfig{}, + Targets: []api.ServiceTarget{testServiceTarget}, + }) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.ServiceRequest + require.NoError(t, json.Unmarshal(reqBytes, &req)) + + require.Len(t, req.Targets, 1) + target := req.Targets[0] + require.NotNil(t, target.Options, "options should be present") + opts := target.Options + require.NotNil(t, opts.SkipTlsVerify, "skip_tls_verify should be present") + assert.True(t, *opts.SkipTlsVerify) + require.NotNil(t, opts.RequestTimeout, "request_timeout should be present") + assert.Equal(t, "30s", *opts.RequestTimeout) + require.NotNil(t, opts.PathRewrite, "path_rewrite should be present") + assert.Equal(t, api.ServiceTargetOptionsPathRewrite("preserve"), *opts.PathRewrite) + require.NotNil(t, opts.CustomHeaders, "custom_headers should be present") + assert.Equal(t, "bar", (*opts.CustomHeaders)["X-Foo"]) + + retBytes, _ := json.Marshal(testService) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + + pathRewrite := api.ServiceTargetOptionsPathRewrite("preserve") + ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{ + Name: "test-service", + Domain: "test.example.com", + Enabled: true, + Auth: api.ServiceAuthConfig{}, + Targets: []api.ServiceTarget{ + { + TargetId: "peer-123", + TargetType: "peer", + Protocol: "https", + Port: 8443, + Enabled: true, + Options: &api.ServiceTargetOptions{ + SkipTlsVerify: ptr(true), + RequestTimeout: ptr("30s"), + PathRewrite: &pathRewrite, + CustomHeaders: &map[string]string{"X-Foo": "bar"}, + }, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, testService.Id, ret.Id) + }) +} + +func TestReverseProxyServices_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.ServiceRequest + require.NoError(t, json.Unmarshal(reqBytes, &req)) + assert.Equal(t, "updated-service", req.Name) + retBytes, _ := json.Marshal(testService) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{ + Name: "updated-service", + Domain: "test.example.com", + Enabled: true, + Auth: api.ServiceAuthConfig{}, + Targets: []api.ServiceTarget{testServiceTarget}, + }) + require.NoError(t, err) + assert.Equal(t, testService.Id, ret.Id) + }) +} + +func TestReverseProxyServices_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{ + Name: "updated-service", + Domain: "test.example.com", + Enabled: true, + Auth: api.ServiceAuthConfig{}, + Targets: []api.ServiceTarget{testServiceTarget}, + }) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestReverseProxyServices_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.ReverseProxyServices.Delete(context.Background(), "svc-1") + require.NoError(t, err) + }) +} + +func TestReverseProxyServices_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.ReverseProxyServices.Delete(context.Background(), "svc-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index b0ce1b5cc..0b855db67 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -68,8 +68,17 @@ tags: - name: MSP description: MSP portal for Tenant management. x-cloud-only: true - - name: IDP - description: Manage identity provider integrations for user and group sync. + - name: IDP SCIM Integrations + description: Manage generic SCIM identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Google Integrations + description: Manage Google Workspace identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Azure Integrations + description: Manage Azure AD identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Okta SCIM Integrations + description: Manage Okta SCIM identity provider integrations for user and group sync. x-cloud-only: true - name: EDR Intune Integrations description: Manage Microsoft Intune EDR integrations. @@ -83,12 +92,19 @@ tags: - name: EDR Huntress Integrations description: Manage Huntress EDR integrations. x-cloud-only: true + - name: EDR FleetDM Integrations + description: Manage FleetDM EDR integrations. + x-cloud-only: true - name: EDR Peers description: Manage EDR compliance bypass for peers. x-cloud-only: true - name: Event Streaming Integrations description: Manage event streaming integrations. x-cloud-only: true + - name: Notifications + description: Manage notification channels for account event alerts. + x-cloud-only: true + components: schemas: @@ -326,6 +342,16 @@ components: type: string format: cidr example: 100.64.0.0/16 + peer_expose_enabled: + description: Enables or disables peer expose. If enabled, peers can expose local services through the reverse proxy using the CLI. + type: boolean + example: false + peer_expose_groups: + description: Limits which peer groups are allowed to expose services. If empty, all peers are allowed when peer expose is enabled. + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 extra: $ref: '#/components/schemas/AccountExtraSettings' lazy_connection_enabled: @@ -337,6 +363,10 @@ components: description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") type: string example: "0.51.2" + auto_update_always: + description: When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. + type: boolean + example: false embedded_idp_enabled: description: Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field. type: boolean @@ -353,6 +383,8 @@ components: - peer_inactivity_expiration_enabled - peer_inactivity_expiration - regular_users_view_blocked + - peer_expose_enabled + - peer_expose_groups AccountExtraSettings: type: object properties: @@ -2810,6 +2842,29 @@ components: type: string description: "City name from geolocation" example: "San Francisco" + subdivision_code: + type: string + description: "First-level administrative subdivision ISO code (e.g. state/province)" + example: "CA" + bytes_upload: + type: integer + format: int64 + description: "Bytes uploaded (request body size)" + example: 1024 + bytes_download: + type: integer + format: int64 + description: "Bytes downloaded (response body size)" + example: 8192 + protocol: + type: string + description: "Protocol type: http, tcp, or udp" + example: "http" + metadata: + type: object + additionalProperties: + type: string + description: "Extra context about the request (e.g. crowdsec_verdict)" required: - id - service_id @@ -2819,6 +2874,8 @@ components: - path - duration_ms - status_code + - bytes_upload + - bytes_download ProxyAccessLogsResponse: type: object properties: @@ -2920,12 +2977,32 @@ components: id: type: string description: Service ID + example: "cs8i4ug6lnn4g9hqv7mg" name: type: string description: Service name + example: "myapp.example.netbird.app" domain: type: string description: Domain for the service + example: "myapp.example.netbird.app" + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + example: "http" + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only) + example: 8443 + port_auto_assigned: + type: boolean + description: Whether the listen port was auto-assigned + readOnly: true + example: false proxy_cluster: type: string description: The proxy cluster handling this service (derived from domain) @@ -2938,14 +3015,24 @@ components: enabled: type: boolean description: Whether the service is enabled + example: true + terminated: + type: boolean + description: Whether the service has been terminated. Terminated services cannot be updated. Services that violate the Terms of Service will be terminated. + readOnly: true + example: false pass_host_header: type: boolean description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address + example: false rewrite_redirects: type: boolean description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain + example: false auth: $ref: '#/components/schemas/ServiceAuthConfig' + access_restrictions: + $ref: '#/components/schemas/AccessRestrictions' meta: $ref: '#/components/schemas/ServiceMeta' required: @@ -2989,9 +3076,23 @@ components: name: type: string description: Service name + example: "myapp.example.netbird.app" domain: type: string description: Domain for the service + example: "myapp.example.netbird.app" + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + example: "http" + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. + example: 5432 targets: type: array items: @@ -3001,46 +3102,94 @@ components: type: boolean description: Whether the service is enabled default: true + example: true pass_host_header: type: boolean description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address + example: false rewrite_redirects: type: boolean description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain + example: false auth: $ref: '#/components/schemas/ServiceAuthConfig' + access_restrictions: + $ref: '#/components/schemas/AccessRestrictions' required: - name - domain - - targets - - auth - enabled + ServiceTargetOptions: + type: object + properties: + skip_tls_verify: + type: boolean + description: Skip TLS certificate verification for this backend + example: false + request_timeout: + type: string + description: Per-target response timeout as a Go duration string (e.g. "30s", "2m") + example: "30s" + path_rewrite: + type: string + description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. + enum: [preserve] + example: "preserve" + custom_headers: + type: object + description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected. + propertyNames: + type: string + pattern: '^[!#$%&''*+.^_`|~0-9A-Za-z-]+$' + additionalProperties: + type: string + pattern: '^[^\r\n]*$' + example: {"X-Custom-Header": "value"} + proxy_protocol: + type: boolean + description: Send PROXY Protocol v2 header to this backend (TCP/TLS only) + example: false + session_idle_timeout: + type: string + description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). + example: "2m" ServiceTarget: type: object properties: target_id: type: string description: Target ID + example: "cs8i4ug6lnn4g9hqv7mg" target_type: type: string - description: Target type (e.g., "peer", "resource") - enum: [peer, resource] + description: Target type + enum: [peer, host, domain, subnet] + example: "subnet" path: type: string - description: URL path prefix for this target + description: URL path prefix for this target (HTTP only) + example: "/" protocol: type: string description: Protocol to use when connecting to the backend - enum: [http, https] + enum: [http, https, tcp, udp] + example: "http" host: type: string description: Backend ip or domain for this target + example: "10.10.0.1" port: type: integer - description: Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + minimum: 1 + maximum: 65535 + description: Backend port for this target + example: 8080 enabled: type: boolean description: Whether this target is enabled + example: true + options: + $ref: '#/components/schemas/ServiceTargetOptions' required: - target_id - target_type @@ -3058,15 +3207,81 @@ components: $ref: '#/components/schemas/BearerAuthConfig' link_auth: $ref: '#/components/schemas/LinkAuthConfig' + header_auths: + type: array + items: + $ref: '#/components/schemas/HeaderAuthConfig' + HeaderAuthConfig: + type: object + description: Static header-value authentication. The proxy checks that the named header matches the configured value. + properties: + enabled: + type: boolean + description: Whether header auth is enabled + example: true + header: + type: string + description: HTTP header name to check (e.g. "Authorization", "X-API-Key") + example: "X-API-Key" + value: + type: string + description: Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses. + example: "my-secret-api-key" + required: + - enabled + - header + - value + AccessRestrictions: + type: object + description: Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. + properties: + allowed_cidrs: + type: array + items: + type: string + format: cidr + example: "192.168.1.0/24" + description: CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed. + blocked_cidrs: + type: array + items: + type: string + format: cidr + example: "10.0.0.0/8" + description: CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs. + allowed_countries: + type: array + items: + type: string + pattern: '^[a-zA-Z]{2}$' + example: "US" + description: ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted. + blocked_countries: + type: array + items: + type: string + pattern: '^[a-zA-Z]{2}$' + example: "DE" + description: ISO 3166-1 alpha-2 country codes to block. + crowdsec_mode: + type: string + enum: + - "off" + - "enforce" + - "observe" + default: "off" + description: CrowdSec IP reputation mode. Only available when the proxy cluster supports CrowdSec. PasswordAuthConfig: type: object properties: enabled: type: boolean description: Whether password auth is enabled + example: true password: type: string description: Auth password + example: "s3cret" required: - enabled - password @@ -3076,9 +3291,11 @@ components: enabled: type: boolean description: Whether PIN auth is enabled + example: false pin: type: string description: PIN value + example: "1234" required: - enabled - pin @@ -3088,10 +3305,12 @@ components: enabled: type: boolean description: Whether bearer auth is enabled + example: true distribution_groups: type: array items: type: string + example: "ch8i4ug6lnn4g9hqv7mg" description: List of group IDs that can use bearer auth required: - enabled @@ -3101,6 +3320,7 @@ components: enabled: type: boolean description: Whether link auth is enabled + example: false required: - enabled ProxyCluster: @@ -3131,17 +3351,33 @@ components: id: type: string description: Domain ID + example: "ds8i4ug6lnn4g9hqv7mg" domain: type: string description: Domain name + example: "example.netbird.app" validated: type: boolean description: Whether the domain has been validated + example: true type: $ref: '#/components/schemas/ReverseProxyDomainType' target_cluster: type: string description: The proxy cluster this domain is validated against (only for custom domains) + example: "eu.proxy.netbird.io" + supports_custom_ports: + type: boolean + description: Whether the cluster supports binding arbitrary TCP/UDP ports + example: true + require_subdomain: + type: boolean + description: Whether a subdomain label is required in front of this domain. When true, the domain cannot be used bare. + example: false + supports_crowdsec: + type: boolean + description: Whether the proxy cluster has CrowdSec configured + example: false required: - id - domain @@ -3153,9 +3389,11 @@ components: domain: type: string description: Domain name + example: "myapp.example.com" target_cluster: type: string description: The proxy cluster this domain should be validated against + example: "eu.proxy.netbird.io" required: - domain - target_cluster @@ -4058,75 +4296,129 @@ components: description: Status of agent firewall. Can be one of Disabled, Enabled, Pending Isolation, Isolated, Pending Release. example: "Enabled" - CreateScimIntegrationRequest: + EDRFleetDMRequest: type: object - description: Request payload for creating an SCIM IDP integration - required: - - prefix - - provider + description: Request payload for creating or updating a FleetDM EDR integration properties: - prefix: + api_url: type: string - description: The connection prefix used for the SCIM provider - provider: + description: FleetDM server URL + api_token: type: string - description: Name of the SCIM identity provider - group_prefixes: + description: FleetDM API token + groups: type: array - description: List of start_with string patterns for groups to sync + description: The Groups this integrations applies to items: type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] - UpdateScimIntegrationRequest: - type: object - description: Request payload for updating an SCIM IDP integration - properties: + last_synced_interval: + type: integer + description: The devices last sync requirement interval in hours. Minimum value is 24 hours + minimum: 24 enabled: type: boolean description: Indicates whether the integration is enabled - example: true - group_prefixes: - type: array - description: List of start_with string patterns for groups to sync - items: - type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] - ScimIntegration: + default: true + match_attributes: + $ref: '#/components/schemas/FleetDMMatchAttributes' + required: + - api_url + - api_token + - groups + - last_synced_interval + - match_attributes + EDRFleetDMResponse: type: object - description: Represents a SCIM IDP integration + description: Represents a FleetDM EDR integration configuration required: - id - - enabled - - provider - - group_prefixes - - user_group_prefixes - - auth_token + - account_id + - api_url + - created_by - last_synced_at + - created_at + - updated_at + - groups + - last_synced_interval + - match_attributes + - enabled properties: id: type: integer format: int64 - description: The unique identifier for the integration + description: The unique numeric identifier for the integration. example: 123 + account_id: + type: string + description: The identifier of the account this integration belongs to. + example: "ch8i4ug6lnn4g9hqv7l0" + api_url: + type: string + description: FleetDM server URL + last_synced_at: + type: string + format: date-time + description: Timestamp of when the integration was last synced. + example: "2023-05-15T10:30:00Z" + created_by: + type: string + description: The user id that created the integration + created_at: + type: string + format: date-time + description: Timestamp of when the integration was created. + example: "2023-05-15T10:30:00Z" + updated_at: + type: string + format: date-time + description: Timestamp of when the integration was last updated. + example: "2023-05-16T11:45:00Z" + groups: + type: array + description: List of groups + items: + $ref: '#/components/schemas/Group' + last_synced_interval: + type: integer + description: The devices last sync requirement interval in hours. enabled: type: boolean description: Indicates whether the integration is enabled - example: true - provider: - type: string - description: Name of the SCIM identity provider + default: true + match_attributes: + $ref: '#/components/schemas/FleetDMMatchAttributes' + + FleetDMMatchAttributes: + type: object + description: Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly + additionalProperties: false + properties: + disk_encryption_enabled: + type: boolean + description: Whether disk encryption (FileVault/BitLocker) must be enabled on the host + failing_policies_count_max: + type: integer + description: Maximum number of allowed failing policies. Use 0 to require all policies to pass + minimum: 0 + example: 0 + vulnerable_software_count_max: + type: integer + description: Maximum number of allowed vulnerable software on the host + minimum: 0 + example: 0 + status_online: + type: boolean + description: Whether the host must be online (recently seen by Fleet) + required_policies: + type: array + description: List of FleetDM policy IDs that must be passing on the host. If any of these policies is failing, the host is non-compliant + items: + type: integer + example: [1, 5, 12] + + IntegrationSyncFilters: + type: object + properties: group_prefixes: type: array description: List of start_with string patterns for groups to sync @@ -4139,15 +4431,77 @@ components: items: type: string example: [ "Users" ] - auth_token: + connector_id: type: string - description: SCIM API token (full on creation, masked otherwise) - example: "nbs_abc***********************************" - last_synced_at: - type: string - format: date-time - description: Timestamp of when the integration was last synced - example: "2023-05-15T10:30:00Z" + description: DEX connector ID for embedded IDP setups + IntegrationEnabled: + type: object + properties: + enabled: + type: boolean + description: Whether the integration is enabled + example: true + CreateScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an SCIM IDP integration + required: + - prefix + - provider + properties: + prefix: + type: string + description: The connection prefix used for the SCIM provider + provider: + type: string + description: Name of the SCIM identity provider + UpdateScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an SCIM IDP integration + properties: + prefix: + type: string + description: The connection prefix used for the SCIM provider + ScimIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents a SCIM IDP integration + required: + - id + - enabled + - prefix + - provider + - group_prefixes + - user_group_prefixes + - auth_token + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 123 + prefix: + type: string + description: The connection prefix used for the SCIM provider + provider: + type: string + description: Name of the SCIM identity provider + auth_token: + type: string + description: SCIM API token (full on creation, masked otherwise) + example: "nbs_abc***********************************" + last_synced_at: + type: string + format: date-time + description: Timestamp of when the integration was last synced + example: "2023-05-15T10:30:00Z" IdpIntegrationSyncLog: type: object description: Represents a synchronization log entry for an integration @@ -4185,6 +4539,346 @@ components: type: string description: The newly generated SCIM API token example: "nbs_F3f0d..." + CreateGoogleIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating a Google Workspace IDP integration + required: + - service_account_key + - customer_id + properties: + service_account_key: + type: string + description: Base64-encoded Google service account key + example: "eyJ0eXBlIjoic2VydmljZV9hY2NvdW50Ii..." + customer_id: + type: string + description: Customer ID from Google Workspace Account Settings + example: "C01234567" + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + minimum: 300 + example: 300 + UpdateGoogleIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating a Google Workspace IDP integration. All fields are optional. + properties: + service_account_key: + type: string + description: Base64-encoded Google service account key + customer_id: + type: string + description: Customer ID from Google Workspace Account Settings + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300) + minimum: 300 + GoogleIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents a Google Workspace IDP integration + required: + - id + - customer_id + - sync_interval + - enabled + - group_prefixes + - user_group_prefixes + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + customer_id: + type: string + description: Customer ID from Google Workspace + example: "C01234567" + sync_interval: + type: integer + description: Sync interval in seconds + example: 300 + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + CreateAzureIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an Azure AD IDP integration + required: + - client_secret + - client_id + - tenant_id + - host + properties: + client_secret: + type: string + description: Base64-encoded Azure AD client secret + example: "c2VjcmV0..." + client_id: + type: string + description: Azure AD application (client) ID + example: "12345678-1234-1234-1234-123456789012" + tenant_id: + type: string + description: Azure AD tenant ID + example: "87654321-4321-4321-4321-210987654321" + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + minimum: 300 + example: 300 + host: + type: string + description: Azure host domain for the Graph API + enum: + - microsoft.com + - microsoft.us + example: "microsoft.com" + UpdateAzureIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an Azure AD IDP integration. All fields are optional. + properties: + client_secret: + type: string + description: Base64-encoded Azure AD client secret + client_id: + type: string + description: Azure AD application (client) ID + tenant_id: + type: string + description: Azure AD tenant ID + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300) + minimum: 300 + AzureIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents an Azure AD IDP integration + required: + - id + - client_id + - tenant_id + - sync_interval + - enabled + - group_prefixes + - user_group_prefixes + - host + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + client_id: + type: string + description: Azure AD application (client) ID + example: "12345678-1234-1234-1234-123456789012" + tenant_id: + type: string + description: Azure AD tenant ID + example: "87654321-4321-4321-4321-210987654321" + sync_interval: + type: integer + description: Sync interval in seconds + example: 300 + host: + type: string + description: Azure host domain for the Graph API + example: "microsoft.com" + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + CreateOktaScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an Okta SCIM IDP integration + required: + - connection_name + properties: + connection_name: + type: string + description: The Okta enterprise connection name on Auth0 + example: "my-okta-connection" + UpdateOktaScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an Okta SCIM IDP integration. All fields are optional. + OktaScimIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents an Okta SCIM IDP integration + required: + - id + - enabled + - group_prefixes + - user_group_prefixes + - auth_token + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + auth_token: + type: string + description: SCIM API token (full on creation/regeneration, masked on retrieval) + example: "nbs_abc***********************************" + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + SyncResult: + type: object + description: Response for a manual sync trigger + properties: + result: + type: string + example: "ok" + NotificationChannelType: + type: string + description: The type of notification channel. + enum: + - email + - webhook + example: "email" + NotificationEventType: + type: string + description: | + An activity event type code. See `GET /api/integrations/notifications/types` for the full list + of supported event types and their human-readable descriptions. + example: "user.join" + EmailTarget: + type: object + description: Target configuration for email notification channels. + properties: + emails: + type: array + description: List of email addresses to send notifications to. + minItems: 1 + items: + type: string + format: email + example: [ "admin@example.com", "ops@example.com" ] + required: + - emails + WebhookTarget: + type: object + description: Target configuration for webhook notification channels. + properties: + url: + type: string + format: uri + description: The webhook endpoint URL to send notifications to. + example: "https://hooks.example.com/netbird" + headers: + type: object + additionalProperties: + type: string + description: | + Custom HTTP headers sent with each webhook request. + Values are write-only; in GET responses all values are masked. + example: + Authorization: "Bearer token" + X-Webhook-Secret: "secret" + required: + - url + NotificationChannelRequest: + type: object + description: Request body for creating or updating a notification channel. + properties: + type: + $ref: '#/components/schemas/NotificationChannelType' + target: + description: | + Channel-specific target configuration. The shape depends on the `type` field: + - `email`: requires an `EmailTarget` object + - `webhook`: requires a `WebhookTarget` object + oneOf: + - $ref: '#/components/schemas/EmailTarget' + - $ref: '#/components/schemas/WebhookTarget' + event_types: + type: array + description: List of activity event type codes this channel subscribes to. + items: + $ref: '#/components/schemas/NotificationEventType' + example: [ "user.join", "peer.user.add", "peer.login.expire" ] + enabled: + type: boolean + description: Whether this notification channel is active. + example: true + required: + - type + - event_types + - enabled + NotificationChannelResponse: + type: object + description: A notification channel configuration. + properties: + id: + type: string + description: Unique identifier of the notification channel. + readOnly: true + example: "ch8i4ug6lnn4g9hqv7m0" + type: + $ref: '#/components/schemas/NotificationChannelType' + target: + description: | + Channel-specific target configuration. The shape depends on the `type` field: + - `email`: an `EmailTarget` object + - `webhook`: a `WebhookTarget` object + oneOf: + - $ref: '#/components/schemas/EmailTarget' + - $ref: '#/components/schemas/WebhookTarget' + event_types: + type: array + description: List of activity event type codes this channel subscribes to. + items: + $ref: '#/components/schemas/NotificationEventType' + example: [ "user.join", "peer.user.add", "peer.login.expire" ] + enabled: + type: boolean + description: Whether this notification channel is active. + example: true + required: + - id + - type + - event_types + - enabled + NotificationTypeEntry: + type: object + description: A map of event type codes to their human-readable descriptions. + additionalProperties: + type: string + example: + user.join: "User joined" BypassResponse: type: object description: Response for bypassed peer operations. @@ -4225,6 +4919,12 @@ components: requires_authentication: description: Requires authentication content: { } + conflict: + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' securitySchemes: BearerAuth: type: http @@ -8815,10 +9515,877 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp: + post: + tags: + - IDP Google Integrations + summary: Create Google IDP Integration + description: Creates a new Google Workspace IDP integration + operationId: createGoogleIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateGoogleIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Google Integrations + summary: Get All Google IDP Integrations + description: Retrieves all Google Workspace IDP integrations for the authenticated account + operationId: getAllGoogleIntegrations + responses: + '200': + description: A list of Google IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/GoogleIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Google Integrations + summary: Get Google IDP Integration + description: Retrieves a Google IDP integration by ID. + operationId: getGoogleIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Google Integrations + summary: Update Google IDP Integration + description: Updates an existing Google Workspace IDP integration. + operationId: updateGoogleIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateGoogleIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Google Integrations + summary: Delete Google IDP Integration + description: Deletes a Google IDP integration by ID. + operationId: deleteGoogleIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}/sync: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Google Integrations + summary: Sync Google IDP Integration + description: Triggers a manual synchronization for a Google IDP integration. + operationId: syncGoogleIntegration + responses: + '200': + description: Sync triggered successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/SyncResult' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Google Integrations + summary: Get Google Integration Sync Logs + description: Retrieves synchronization logs for a Google IDP integration. + operationId: getGoogleIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp: + post: + tags: + - IDP Azure Integrations + summary: Create Azure IDP Integration + description: Creates a new Azure AD IDP integration + operationId: createAzureIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateAzureIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Azure Integrations + summary: Get All Azure IDP Integrations + description: Retrieves all Azure AD IDP integrations for the authenticated account + operationId: getAllAzureIntegrations + responses: + '200': + description: A list of Azure IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AzureIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Azure Integrations + summary: Get Azure IDP Integration + description: Retrieves an Azure IDP integration by ID. + operationId: getAzureIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Azure Integrations + summary: Update Azure IDP Integration + description: Updates an existing Azure AD IDP integration. + operationId: updateAzureIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateAzureIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Azure Integrations + summary: Delete Azure IDP Integration + description: Deletes an Azure IDP integration by ID. + operationId: deleteAzureIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}/sync: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Azure Integrations + summary: Sync Azure IDP Integration + description: Triggers a manual synchronization for an Azure IDP integration. + operationId: syncAzureIntegration + responses: + '200': + description: Sync triggered successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/SyncResult' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Azure Integrations + summary: Get Azure Integration Sync Logs + description: Retrieves synchronization logs for an Azure IDP integration. + operationId: getAzureIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp: + post: + tags: + - IDP Okta SCIM Integrations + summary: Create Okta SCIM IDP Integration + description: Creates a new Okta SCIM IDP integration + operationId: createOktaScimIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateOktaScimIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Okta SCIM Integrations + summary: Get All Okta SCIM IDP Integrations + description: Retrieves all Okta SCIM IDP integrations for the authenticated account + operationId: getAllOktaScimIntegrations + responses: + '200': + description: A list of Okta SCIM IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/OktaScimIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Okta SCIM Integrations + summary: Get Okta SCIM IDP Integration + description: Retrieves an Okta SCIM IDP integration by ID. + operationId: getOktaScimIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Okta SCIM Integrations + summary: Update Okta SCIM IDP Integration + description: Updates an existing Okta SCIM IDP integration. + operationId: updateOktaScimIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateOktaScimIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Okta SCIM Integrations + summary: Delete Okta SCIM IDP Integration + description: Deletes an Okta SCIM IDP integration by ID. + operationId: deleteOktaScimIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}/token: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Okta SCIM Integrations + summary: Regenerate Okta SCIM Token + description: Regenerates the SCIM API token for an Okta SCIM IDP integration. + operationId: regenerateOktaScimToken + responses: + '200': + description: Token regenerated successfully. Returns the new token. + content: + application/json: + schema: + $ref: '#/components/schemas/ScimTokenResponse' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Okta SCIM Integrations + summary: Get Okta SCIM Integration Sync Logs + description: Retrieves synchronization logs for an Okta SCIM IDP integration. + operationId: getOktaScimIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/integrations/scim-idp: post: tags: - - IDP + - IDP SCIM Integrations summary: Create SCIM IDP Integration description: Creates a new SCIM integration operationId: createSCIMIntegration @@ -8855,7 +10422,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' get: tags: - - IDP + - IDP SCIM Integrations summary: Get All SCIM IDP Integrations description: Retrieves all SCIM IDP integrations for the authenticated account operationId: getAllSCIMIntegrations @@ -8887,11 +10454,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 get: tags: - - IDP + - IDP SCIM Integrations summary: Get SCIM IDP Integration description: Retrieves an SCIM IDP integration by ID. operationId: getSCIMIntegration @@ -8928,7 +10496,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' put: tags: - - IDP + - IDP SCIM Integrations summary: Update SCIM IDP Integration description: Updates an existing SCIM IDP Integration. operationId: updateSCIMIntegration @@ -8971,7 +10539,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' delete: tags: - - IDP + - IDP SCIM Integrations summary: Delete SCIM IDP Integration description: Deletes an SCIM IDP integration by ID. operationId: deleteSCIMIntegration @@ -9014,11 +10582,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 post: tags: - - IDP + - IDP SCIM Integrations summary: Regenerate SCIM Token description: Regenerates the SCIM API token for an SCIM IDP integration. operationId: regenerateSCIMToken @@ -9060,11 +10629,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 get: tags: - - IDP + - IDP SCIM Integrations summary: Get SCIM Integration Sync Logs description: Retrieves synchronization logs for a SCIM IDP integration. operationId: getSCIMIntegrationLogs @@ -9257,6 +10827,161 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /api/integrations/edr/fleetdm: + post: + tags: + - EDR FleetDM Integrations + summary: Create EDR FleetDM Integration + description: Creates a new EDR FleetDM integration + operationId: createFleetDMEDRIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMResponse' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - EDR FleetDM Integrations + summary: Get EDR FleetDM Integration + description: Retrieves a specific EDR FleetDM integration by its ID. + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMResponse' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - EDR FleetDM Integrations + summary: Update EDR FleetDM Integration + description: Updates an existing EDR FleetDM Integration. + operationId: updateFleetDMEDRIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/EDRFleetDMResponse' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - EDR FleetDM Integrations + summary: Delete EDR FleetDM Integration + description: Deletes an EDR FleetDM Integration by its ID. + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/peers/{peer-id}/edr/bypass: parameters: - name: peer-id @@ -9569,6 +11294,29 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /api/reverse-proxies/clusters: + get: + summary: List available proxy clusters + description: Returns a list of available proxy clusters with their connection status + tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy clusters + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyCluster' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services @@ -9618,29 +11366,8 @@ paths: "$ref": "#/components/responses/requires_authentication" '403': "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - /api/reverse-proxies/clusters: - get: - summary: List available proxy clusters - description: Returns a list of available proxy clusters with their connection status - tags: [ Services ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - responses: - '200': - description: A JSON Array of proxy clusters - content: - application/json: - schema: - type: array - items: - $ref: '#/components/schemas/ProxyCluster' - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services/{serviceId}: @@ -9710,6 +11437,8 @@ paths: "$ref": "#/components/responses/forbidden" '404': "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" delete: @@ -9852,3 +11581,172 @@ paths: "$ref": "#/components/responses/not_found" '500': "$ref": "#/components/responses/internal_error" + /api/integrations/notifications/types: + get: + tags: + - Notifications + summary: List Notification Event Types + description: | + Returns a map of all supported activity event type codes to their + human-readable descriptions. Use these codes when configuring + `event_types` on notification channels. + operationId: listNotificationEventTypes + responses: + '200': + description: A map of event type codes to descriptions. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationTypeEntry' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/integrations/notifications/channels: + get: + tags: + - Notifications + summary: List Notification Channels + description: Retrieves all notification channels configured for the authenticated account. + operationId: listNotificationChannels + responses: + '200': + description: A list of notification channels. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + post: + tags: + - Notifications + summary: Create Notification Channel + description: | + Creates a new notification channel for the authenticated account. + Supported channel types are `email` and `webhook`. + operationId: createNotificationChannel + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelRequest' + responses: + '200': + description: Notification channel created successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/integrations/notifications/channels/{channelId}: + parameters: + - name: channelId + in: path + required: true + description: The unique identifier of the notification channel. + schema: + type: string + example: "ch8i4ug6lnn4g9hqv7m0" + get: + tags: + - Notifications + summary: Get Notification Channel + description: Retrieves a specific notification channel by its ID. + operationId: getNotificationChannel + responses: + '200': + description: Successfully retrieved the notification channel. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + tags: + - Notifications + summary: Update Notification Channel + description: Updates an existing notification channel. + operationId: updateNotificationChannel + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelRequest' + responses: + '200': + description: Notification channel updated successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + delete: + tags: + - Notifications + summary: Delete Notification Channel + description: Deletes a notification channel by its ID. + operationId: deleteNotificationChannel + responses: + '200': + description: Notification channel deleted successfully. + content: + application/json: + schema: + type: object + example: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 7a7e75855..0317b8183 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT. +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT. package api import ( @@ -9,6 +9,7 @@ import ( "time" "github.com/oapi-codegen/runtime" + openapi_types "github.com/oapi-codegen/runtime/types" ) const ( @@ -16,6 +17,45 @@ const ( TokenAuthScopes = "TokenAuth.Scopes" ) +// Defines values for AccessRestrictionsCrowdsecMode. +const ( + AccessRestrictionsCrowdsecModeEnforce AccessRestrictionsCrowdsecMode = "enforce" + AccessRestrictionsCrowdsecModeObserve AccessRestrictionsCrowdsecMode = "observe" + AccessRestrictionsCrowdsecModeOff AccessRestrictionsCrowdsecMode = "off" +) + +// Valid indicates whether the value is a known member of the AccessRestrictionsCrowdsecMode enum. +func (e AccessRestrictionsCrowdsecMode) Valid() bool { + switch e { + case AccessRestrictionsCrowdsecModeEnforce: + return true + case AccessRestrictionsCrowdsecModeObserve: + return true + case AccessRestrictionsCrowdsecModeOff: + return true + default: + return false + } +} + +// Defines values for CreateAzureIntegrationRequestHost. +const ( + CreateAzureIntegrationRequestHostMicrosoftCom CreateAzureIntegrationRequestHost = "microsoft.com" + CreateAzureIntegrationRequestHostMicrosoftUs CreateAzureIntegrationRequestHost = "microsoft.us" +) + +// Valid indicates whether the value is a known member of the CreateAzureIntegrationRequestHost enum. +func (e CreateAzureIntegrationRequestHost) Valid() bool { + switch e { + case CreateAzureIntegrationRequestHostMicrosoftCom: + return true + case CreateAzureIntegrationRequestHostMicrosoftUs: + return true + default: + return false + } +} + // Defines values for CreateIntegrationRequestPlatform. const ( CreateIntegrationRequestPlatformDatadog CreateIntegrationRequestPlatform = "datadog" @@ -24,6 +64,22 @@ const ( CreateIntegrationRequestPlatformS3 CreateIntegrationRequestPlatform = "s3" ) +// Valid indicates whether the value is a known member of the CreateIntegrationRequestPlatform enum. +func (e CreateIntegrationRequestPlatform) Valid() bool { + switch e { + case CreateIntegrationRequestPlatformDatadog: + return true + case CreateIntegrationRequestPlatformFirehose: + return true + case CreateIntegrationRequestPlatformGenericHttp: + return true + case CreateIntegrationRequestPlatformS3: + return true + default: + return false + } +} + // Defines values for DNSRecordType. const ( DNSRecordTypeA DNSRecordType = "A" @@ -31,6 +87,20 @@ const ( DNSRecordTypeCNAME DNSRecordType = "CNAME" ) +// Valid indicates whether the value is a known member of the DNSRecordType enum. +func (e DNSRecordType) Valid() bool { + switch e { + case DNSRecordTypeA: + return true + case DNSRecordTypeAAAA: + return true + case DNSRecordTypeCNAME: + return true + default: + return false + } +} + // Defines values for EventActivityCode. const ( EventActivityCodeAccountCreate EventActivityCode = "account.create" @@ -147,12 +217,256 @@ const ( EventActivityCodeUserUnblock EventActivityCode = "user.unblock" ) +// Valid indicates whether the value is a known member of the EventActivityCode enum. +func (e EventActivityCode) Valid() bool { + switch e { + case EventActivityCodeAccountCreate: + return true + case EventActivityCodeAccountDelete: + return true + case EventActivityCodeAccountDnsDomainUpdate: + return true + case EventActivityCodeAccountNetworkRangeUpdate: + return true + case EventActivityCodeAccountPeerInactivityExpirationDisable: + return true + case EventActivityCodeAccountPeerInactivityExpirationEnable: + return true + case EventActivityCodeAccountPeerInactivityExpirationUpdate: + return true + case EventActivityCodeAccountSettingGroupPropagationDisable: + return true + case EventActivityCodeAccountSettingGroupPropagationEnable: + return true + case EventActivityCodeAccountSettingLazyConnectionDisable: + return true + case EventActivityCodeAccountSettingLazyConnectionEnable: + return true + case EventActivityCodeAccountSettingPeerApprovalDisable: + return true + case EventActivityCodeAccountSettingPeerApprovalEnable: + return true + case EventActivityCodeAccountSettingPeerLoginExpirationDisable: + return true + case EventActivityCodeAccountSettingPeerLoginExpirationEnable: + return true + case EventActivityCodeAccountSettingPeerLoginExpirationUpdate: + return true + case EventActivityCodeAccountSettingRoutingPeerDnsResolutionDisable: + return true + case EventActivityCodeAccountSettingRoutingPeerDnsResolutionEnable: + return true + case EventActivityCodeAccountSettingsAutoVersionUpdate: + return true + case EventActivityCodeDashboardLogin: + return true + case EventActivityCodeDnsSettingDisabledManagementGroupAdd: + return true + case EventActivityCodeDnsSettingDisabledManagementGroupDelete: + return true + case EventActivityCodeDnsZoneCreate: + return true + case EventActivityCodeDnsZoneDelete: + return true + case EventActivityCodeDnsZoneRecordCreate: + return true + case EventActivityCodeDnsZoneRecordDelete: + return true + case EventActivityCodeDnsZoneRecordUpdate: + return true + case EventActivityCodeDnsZoneUpdate: + return true + case EventActivityCodeGroupAdd: + return true + case EventActivityCodeGroupDelete: + return true + case EventActivityCodeGroupUpdate: + return true + case EventActivityCodeIdentityproviderCreate: + return true + case EventActivityCodeIdentityproviderDelete: + return true + case EventActivityCodeIdentityproviderUpdate: + return true + case EventActivityCodeIntegrationCreate: + return true + case EventActivityCodeIntegrationDelete: + return true + case EventActivityCodeIntegrationUpdate: + return true + case EventActivityCodeNameserverGroupAdd: + return true + case EventActivityCodeNameserverGroupDelete: + return true + case EventActivityCodeNameserverGroupUpdate: + return true + case EventActivityCodeNetworkCreate: + return true + case EventActivityCodeNetworkDelete: + return true + case EventActivityCodeNetworkResourceCreate: + return true + case EventActivityCodeNetworkResourceDelete: + return true + case EventActivityCodeNetworkResourceUpdate: + return true + case EventActivityCodeNetworkRouterCreate: + return true + case EventActivityCodeNetworkRouterDelete: + return true + case EventActivityCodeNetworkRouterUpdate: + return true + case EventActivityCodeNetworkUpdate: + return true + case EventActivityCodePeerApprovalRevoke: + return true + case EventActivityCodePeerApprove: + return true + case EventActivityCodePeerGroupAdd: + return true + case EventActivityCodePeerGroupDelete: + return true + case EventActivityCodePeerInactivityExpirationDisable: + return true + case EventActivityCodePeerInactivityExpirationEnable: + return true + case EventActivityCodePeerIpUpdate: + return true + case EventActivityCodePeerJobCreate: + return true + case EventActivityCodePeerLoginExpirationDisable: + return true + case EventActivityCodePeerLoginExpirationEnable: + return true + case EventActivityCodePeerLoginExpire: + return true + case EventActivityCodePeerRename: + return true + case EventActivityCodePeerSetupkeyAdd: + return true + case EventActivityCodePeerSshDisable: + return true + case EventActivityCodePeerSshEnable: + return true + case EventActivityCodePeerUserAdd: + return true + case EventActivityCodePersonalAccessTokenCreate: + return true + case EventActivityCodePersonalAccessTokenDelete: + return true + case EventActivityCodePolicyAdd: + return true + case EventActivityCodePolicyDelete: + return true + case EventActivityCodePolicyUpdate: + return true + case EventActivityCodePostureCheckCreate: + return true + case EventActivityCodePostureCheckDelete: + return true + case EventActivityCodePostureCheckUpdate: + return true + case EventActivityCodeResourceGroupAdd: + return true + case EventActivityCodeResourceGroupDelete: + return true + case EventActivityCodeRouteAdd: + return true + case EventActivityCodeRouteDelete: + return true + case EventActivityCodeRouteUpdate: + return true + case EventActivityCodeRuleAdd: + return true + case EventActivityCodeRuleDelete: + return true + case EventActivityCodeRuleUpdate: + return true + case EventActivityCodeServiceCreate: + return true + case EventActivityCodeServiceDelete: + return true + case EventActivityCodeServiceUpdate: + return true + case EventActivityCodeServiceUserCreate: + return true + case EventActivityCodeServiceUserDelete: + return true + case EventActivityCodeSetupkeyAdd: + return true + case EventActivityCodeSetupkeyDelete: + return true + case EventActivityCodeSetupkeyGroupAdd: + return true + case EventActivityCodeSetupkeyGroupDelete: + return true + case EventActivityCodeSetupkeyOveruse: + return true + case EventActivityCodeSetupkeyRevoke: + return true + case EventActivityCodeSetupkeyUpdate: + return true + case EventActivityCodeTransferredOwnerRole: + return true + case EventActivityCodeUserApprove: + return true + case EventActivityCodeUserBlock: + return true + case EventActivityCodeUserCreate: + return true + case EventActivityCodeUserDelete: + return true + case EventActivityCodeUserGroupAdd: + return true + case EventActivityCodeUserGroupDelete: + return true + case EventActivityCodeUserInvite: + return true + case EventActivityCodeUserInviteLinkAccept: + return true + case EventActivityCodeUserInviteLinkCreate: + return true + case EventActivityCodeUserInviteLinkDelete: + return true + case EventActivityCodeUserInviteLinkRegenerate: + return true + case EventActivityCodeUserJoin: + return true + case EventActivityCodeUserPasswordChange: + return true + case EventActivityCodeUserPeerDelete: + return true + case EventActivityCodeUserPeerLogin: + return true + case EventActivityCodeUserReject: + return true + case EventActivityCodeUserRoleUpdate: + return true + case EventActivityCodeUserUnblock: + return true + default: + return false + } +} + // Defines values for GeoLocationCheckAction. const ( GeoLocationCheckActionAllow GeoLocationCheckAction = "allow" GeoLocationCheckActionDeny GeoLocationCheckAction = "deny" ) +// Valid indicates whether the value is a known member of the GeoLocationCheckAction enum. +func (e GeoLocationCheckAction) Valid() bool { + switch e { + case GeoLocationCheckActionAllow: + return true + case GeoLocationCheckActionDeny: + return true + default: + return false + } +} + // Defines values for GroupIssued. const ( GroupIssuedApi GroupIssued = "api" @@ -160,6 +474,20 @@ const ( GroupIssuedJwt GroupIssued = "jwt" ) +// Valid indicates whether the value is a known member of the GroupIssued enum. +func (e GroupIssued) Valid() bool { + switch e { + case GroupIssuedApi: + return true + case GroupIssuedIntegration: + return true + case GroupIssuedJwt: + return true + default: + return false + } +} + // Defines values for GroupMinimumIssued. const ( GroupMinimumIssuedApi GroupMinimumIssued = "api" @@ -167,6 +495,20 @@ const ( GroupMinimumIssuedJwt GroupMinimumIssued = "jwt" ) +// Valid indicates whether the value is a known member of the GroupMinimumIssued enum. +func (e GroupMinimumIssued) Valid() bool { + switch e { + case GroupMinimumIssuedApi: + return true + case GroupMinimumIssuedIntegration: + return true + case GroupMinimumIssuedJwt: + return true + default: + return false + } +} + // Defines values for IdentityProviderType. const ( IdentityProviderTypeEntra IdentityProviderType = "entra" @@ -178,6 +520,28 @@ const ( IdentityProviderTypeZitadel IdentityProviderType = "zitadel" ) +// Valid indicates whether the value is a known member of the IdentityProviderType enum. +func (e IdentityProviderType) Valid() bool { + switch e { + case IdentityProviderTypeEntra: + return true + case IdentityProviderTypeGoogle: + return true + case IdentityProviderTypeMicrosoft: + return true + case IdentityProviderTypeOidc: + return true + case IdentityProviderTypeOkta: + return true + case IdentityProviderTypePocketid: + return true + case IdentityProviderTypeZitadel: + return true + default: + return false + } +} + // Defines values for IngressPortAllocationPortMappingProtocol. const ( IngressPortAllocationPortMappingProtocolTcp IngressPortAllocationPortMappingProtocol = "tcp" @@ -185,6 +549,20 @@ const ( IngressPortAllocationPortMappingProtocolUdp IngressPortAllocationPortMappingProtocol = "udp" ) +// Valid indicates whether the value is a known member of the IngressPortAllocationPortMappingProtocol enum. +func (e IngressPortAllocationPortMappingProtocol) Valid() bool { + switch e { + case IngressPortAllocationPortMappingProtocolTcp: + return true + case IngressPortAllocationPortMappingProtocolTcpudp: + return true + case IngressPortAllocationPortMappingProtocolUdp: + return true + default: + return false + } +} + // Defines values for IngressPortAllocationRequestDirectPortProtocol. const ( IngressPortAllocationRequestDirectPortProtocolTcp IngressPortAllocationRequestDirectPortProtocol = "tcp" @@ -192,6 +570,20 @@ const ( IngressPortAllocationRequestDirectPortProtocolUdp IngressPortAllocationRequestDirectPortProtocol = "udp" ) +// Valid indicates whether the value is a known member of the IngressPortAllocationRequestDirectPortProtocol enum. +func (e IngressPortAllocationRequestDirectPortProtocol) Valid() bool { + switch e { + case IngressPortAllocationRequestDirectPortProtocolTcp: + return true + case IngressPortAllocationRequestDirectPortProtocolTcpudp: + return true + case IngressPortAllocationRequestDirectPortProtocolUdp: + return true + default: + return false + } +} + // Defines values for IngressPortAllocationRequestPortRangeProtocol. const ( IngressPortAllocationRequestPortRangeProtocolTcp IngressPortAllocationRequestPortRangeProtocol = "tcp" @@ -199,6 +591,20 @@ const ( IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp" ) +// Valid indicates whether the value is a known member of the IngressPortAllocationRequestPortRangeProtocol enum. +func (e IngressPortAllocationRequestPortRangeProtocol) Valid() bool { + switch e { + case IngressPortAllocationRequestPortRangeProtocolTcp: + return true + case IngressPortAllocationRequestPortRangeProtocolTcpudp: + return true + case IngressPortAllocationRequestPortRangeProtocolUdp: + return true + default: + return false + } +} + // Defines values for IntegrationResponsePlatform. const ( IntegrationResponsePlatformDatadog IntegrationResponsePlatform = "datadog" @@ -207,12 +613,40 @@ const ( IntegrationResponsePlatformS3 IntegrationResponsePlatform = "s3" ) +// Valid indicates whether the value is a known member of the IntegrationResponsePlatform enum. +func (e IntegrationResponsePlatform) Valid() bool { + switch e { + case IntegrationResponsePlatformDatadog: + return true + case IntegrationResponsePlatformFirehose: + return true + case IntegrationResponsePlatformGenericHttp: + return true + case IntegrationResponsePlatformS3: + return true + default: + return false + } +} + // Defines values for InvoiceResponseType. const ( InvoiceResponseTypeAccount InvoiceResponseType = "account" InvoiceResponseTypeTenants InvoiceResponseType = "tenants" ) +// Valid indicates whether the value is a known member of the InvoiceResponseType enum. +func (e InvoiceResponseType) Valid() bool { + switch e { + case InvoiceResponseTypeAccount: + return true + case InvoiceResponseTypeTenants: + return true + default: + return false + } +} + // Defines values for JobResponseStatus. const ( JobResponseStatusFailed JobResponseStatus = "failed" @@ -220,11 +654,35 @@ const ( JobResponseStatusSucceeded JobResponseStatus = "succeeded" ) +// Valid indicates whether the value is a known member of the JobResponseStatus enum. +func (e JobResponseStatus) Valid() bool { + switch e { + case JobResponseStatusFailed: + return true + case JobResponseStatusPending: + return true + case JobResponseStatusSucceeded: + return true + default: + return false + } +} + // Defines values for NameserverNsType. const ( NameserverNsTypeUdp NameserverNsType = "udp" ) +// Valid indicates whether the value is a known member of the NameserverNsType enum. +func (e NameserverNsType) Valid() bool { + switch e { + case NameserverNsTypeUdp: + return true + default: + return false + } +} + // Defines values for NetworkResourceType. const ( NetworkResourceTypeDomain NetworkResourceType = "domain" @@ -232,18 +690,74 @@ const ( NetworkResourceTypeSubnet NetworkResourceType = "subnet" ) +// Valid indicates whether the value is a known member of the NetworkResourceType enum. +func (e NetworkResourceType) Valid() bool { + switch e { + case NetworkResourceTypeDomain: + return true + case NetworkResourceTypeHost: + return true + case NetworkResourceTypeSubnet: + return true + default: + return false + } +} + +// Defines values for NotificationChannelType. +const ( + NotificationChannelTypeEmail NotificationChannelType = "email" + NotificationChannelTypeWebhook NotificationChannelType = "webhook" +) + +// Valid indicates whether the value is a known member of the NotificationChannelType enum. +func (e NotificationChannelType) Valid() bool { + switch e { + case NotificationChannelTypeEmail: + return true + case NotificationChannelTypeWebhook: + return true + default: + return false + } +} + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" PeerNetworkRangeCheckActionDeny PeerNetworkRangeCheckAction = "deny" ) +// Valid indicates whether the value is a known member of the PeerNetworkRangeCheckAction enum. +func (e PeerNetworkRangeCheckAction) Valid() bool { + switch e { + case PeerNetworkRangeCheckActionAllow: + return true + case PeerNetworkRangeCheckActionDeny: + return true + default: + return false + } +} + // Defines values for PolicyRuleAction. const ( PolicyRuleActionAccept PolicyRuleAction = "accept" PolicyRuleActionDrop PolicyRuleAction = "drop" ) +// Valid indicates whether the value is a known member of the PolicyRuleAction enum. +func (e PolicyRuleAction) Valid() bool { + switch e { + case PolicyRuleActionAccept: + return true + case PolicyRuleActionDrop: + return true + default: + return false + } +} + // Defines values for PolicyRuleProtocol. const ( PolicyRuleProtocolAll PolicyRuleProtocol = "all" @@ -253,12 +767,42 @@ const ( PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" ) +// Valid indicates whether the value is a known member of the PolicyRuleProtocol enum. +func (e PolicyRuleProtocol) Valid() bool { + switch e { + case PolicyRuleProtocolAll: + return true + case PolicyRuleProtocolIcmp: + return true + case PolicyRuleProtocolNetbirdSsh: + return true + case PolicyRuleProtocolTcp: + return true + case PolicyRuleProtocolUdp: + return true + default: + return false + } +} + // Defines values for PolicyRuleMinimumAction. const ( PolicyRuleMinimumActionAccept PolicyRuleMinimumAction = "accept" PolicyRuleMinimumActionDrop PolicyRuleMinimumAction = "drop" ) +// Valid indicates whether the value is a known member of the PolicyRuleMinimumAction enum. +func (e PolicyRuleMinimumAction) Valid() bool { + switch e { + case PolicyRuleMinimumActionAccept: + return true + case PolicyRuleMinimumActionDrop: + return true + default: + return false + } +} + // Defines values for PolicyRuleMinimumProtocol. const ( PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" @@ -268,12 +812,42 @@ const ( PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" ) +// Valid indicates whether the value is a known member of the PolicyRuleMinimumProtocol enum. +func (e PolicyRuleMinimumProtocol) Valid() bool { + switch e { + case PolicyRuleMinimumProtocolAll: + return true + case PolicyRuleMinimumProtocolIcmp: + return true + case PolicyRuleMinimumProtocolNetbirdSsh: + return true + case PolicyRuleMinimumProtocolTcp: + return true + case PolicyRuleMinimumProtocolUdp: + return true + default: + return false + } +} + // Defines values for PolicyRuleUpdateAction. const ( PolicyRuleUpdateActionAccept PolicyRuleUpdateAction = "accept" PolicyRuleUpdateActionDrop PolicyRuleUpdateAction = "drop" ) +// Valid indicates whether the value is a known member of the PolicyRuleUpdateAction enum. +func (e PolicyRuleUpdateAction) Valid() bool { + switch e { + case PolicyRuleUpdateActionAccept: + return true + case PolicyRuleUpdateActionDrop: + return true + default: + return false + } +} + // Defines values for PolicyRuleUpdateProtocol. const ( PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" @@ -283,6 +857,24 @@ const ( PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) +// Valid indicates whether the value is a known member of the PolicyRuleUpdateProtocol enum. +func (e PolicyRuleUpdateProtocol) Valid() bool { + switch e { + case PolicyRuleUpdateProtocolAll: + return true + case PolicyRuleUpdateProtocolIcmp: + return true + case PolicyRuleUpdateProtocolNetbirdSsh: + return true + case PolicyRuleUpdateProtocolTcp: + return true + case PolicyRuleUpdateProtocolUdp: + return true + default: + return false + } +} + // Defines values for ResourceType. const ( ResourceTypeDomain ResourceType = "domain" @@ -291,12 +883,40 @@ const ( ResourceTypeSubnet ResourceType = "subnet" ) +// Valid indicates whether the value is a known member of the ResourceType enum. +func (e ResourceType) Valid() bool { + switch e { + case ResourceTypeDomain: + return true + case ResourceTypeHost: + return true + case ResourceTypePeer: + return true + case ResourceTypeSubnet: + return true + default: + return false + } +} + // Defines values for ReverseProxyDomainType. const ( ReverseProxyDomainTypeCustom ReverseProxyDomainType = "custom" ReverseProxyDomainTypeFree ReverseProxyDomainType = "free" ) +// Valid indicates whether the value is a known member of the ReverseProxyDomainType enum. +func (e ReverseProxyDomainType) Valid() bool { + switch e { + case ReverseProxyDomainTypeCustom: + return true + case ReverseProxyDomainTypeFree: + return true + default: + return false + } +} + // Defines values for SentinelOneMatchAttributesNetworkStatus. const ( SentinelOneMatchAttributesNetworkStatusConnected SentinelOneMatchAttributesNetworkStatus = "connected" @@ -304,6 +924,44 @@ const ( SentinelOneMatchAttributesNetworkStatusQuarantined SentinelOneMatchAttributesNetworkStatus = "quarantined" ) +// Valid indicates whether the value is a known member of the SentinelOneMatchAttributesNetworkStatus enum. +func (e SentinelOneMatchAttributesNetworkStatus) Valid() bool { + switch e { + case SentinelOneMatchAttributesNetworkStatusConnected: + return true + case SentinelOneMatchAttributesNetworkStatusDisconnected: + return true + case SentinelOneMatchAttributesNetworkStatusQuarantined: + return true + default: + return false + } +} + +// Defines values for ServiceMode. +const ( + ServiceModeHttp ServiceMode = "http" + ServiceModeTcp ServiceMode = "tcp" + ServiceModeTls ServiceMode = "tls" + ServiceModeUdp ServiceMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceMode enum. +func (e ServiceMode) Valid() bool { + switch e { + case ServiceModeHttp: + return true + case ServiceModeTcp: + return true + case ServiceModeTls: + return true + case ServiceModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceMetaStatus. const ( ServiceMetaStatusActive ServiceMetaStatus = "active" @@ -314,18 +972,113 @@ const ( ServiceMetaStatusTunnelNotCreated ServiceMetaStatus = "tunnel_not_created" ) +// Valid indicates whether the value is a known member of the ServiceMetaStatus enum. +func (e ServiceMetaStatus) Valid() bool { + switch e { + case ServiceMetaStatusActive: + return true + case ServiceMetaStatusCertificateFailed: + return true + case ServiceMetaStatusCertificatePending: + return true + case ServiceMetaStatusError: + return true + case ServiceMetaStatusPending: + return true + case ServiceMetaStatusTunnelNotCreated: + return true + default: + return false + } +} + +// Defines values for ServiceRequestMode. +const ( + ServiceRequestModeHttp ServiceRequestMode = "http" + ServiceRequestModeTcp ServiceRequestMode = "tcp" + ServiceRequestModeTls ServiceRequestMode = "tls" + ServiceRequestModeUdp ServiceRequestMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceRequestMode enum. +func (e ServiceRequestMode) Valid() bool { + switch e { + case ServiceRequestModeHttp: + return true + case ServiceRequestModeTcp: + return true + case ServiceRequestModeTls: + return true + case ServiceRequestModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceTargetProtocol. const ( ServiceTargetProtocolHttp ServiceTargetProtocol = "http" ServiceTargetProtocolHttps ServiceTargetProtocol = "https" + ServiceTargetProtocolTcp ServiceTargetProtocol = "tcp" + ServiceTargetProtocolUdp ServiceTargetProtocol = "udp" ) +// Valid indicates whether the value is a known member of the ServiceTargetProtocol enum. +func (e ServiceTargetProtocol) Valid() bool { + switch e { + case ServiceTargetProtocolHttp: + return true + case ServiceTargetProtocolHttps: + return true + case ServiceTargetProtocolTcp: + return true + case ServiceTargetProtocolUdp: + return true + default: + return false + } +} + // Defines values for ServiceTargetTargetType. const ( - ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" - ServiceTargetTargetTypeResource ServiceTargetTargetType = "resource" + ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain" + ServiceTargetTargetTypeHost ServiceTargetTargetType = "host" + ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" + ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet" ) +// Valid indicates whether the value is a known member of the ServiceTargetTargetType enum. +func (e ServiceTargetTargetType) Valid() bool { + switch e { + case ServiceTargetTargetTypeDomain: + return true + case ServiceTargetTargetTypeHost: + return true + case ServiceTargetTargetTypePeer: + return true + case ServiceTargetTargetTypeSubnet: + return true + default: + return false + } +} + +// Defines values for ServiceTargetOptionsPathRewrite. +const ( + ServiceTargetOptionsPathRewritePreserve ServiceTargetOptionsPathRewrite = "preserve" +) + +// Valid indicates whether the value is a known member of the ServiceTargetOptionsPathRewrite enum. +func (e ServiceTargetOptionsPathRewrite) Valid() bool { + switch e { + case ServiceTargetOptionsPathRewritePreserve: + return true + default: + return false + } +} + // Defines values for TenantResponseStatus. const ( TenantResponseStatusActive TenantResponseStatus = "active" @@ -334,6 +1087,22 @@ const ( TenantResponseStatusPending TenantResponseStatus = "pending" ) +// Valid indicates whether the value is a known member of the TenantResponseStatus enum. +func (e TenantResponseStatus) Valid() bool { + switch e { + case TenantResponseStatusActive: + return true + case TenantResponseStatusExisting: + return true + case TenantResponseStatusInvited: + return true + case TenantResponseStatusPending: + return true + default: + return false + } +} + // Defines values for UserStatus. const ( UserStatusActive UserStatus = "active" @@ -341,11 +1110,35 @@ const ( UserStatusInvited UserStatus = "invited" ) +// Valid indicates whether the value is a known member of the UserStatus enum. +func (e UserStatus) Valid() bool { + switch e { + case UserStatusActive: + return true + case UserStatusBlocked: + return true + case UserStatusInvited: + return true + default: + return false + } +} + // Defines values for WorkloadType. const ( WorkloadTypeBundle WorkloadType = "bundle" ) +// Valid indicates whether the value is a known member of the WorkloadType enum. +func (e WorkloadType) Valid() bool { + switch e { + case WorkloadTypeBundle: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsType. const ( GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" @@ -354,12 +1147,40 @@ const ( GetApiEventsNetworkTrafficParamsTypeTYPEUNKNOWN GetApiEventsNetworkTrafficParamsType = "TYPE_UNKNOWN" ) +// Valid indicates whether the value is a known member of the GetApiEventsNetworkTrafficParamsType enum. +func (e GetApiEventsNetworkTrafficParamsType) Valid() bool { + switch e { + case GetApiEventsNetworkTrafficParamsTypeTYPEDROP: + return true + case GetApiEventsNetworkTrafficParamsTypeTYPEEND: + return true + case GetApiEventsNetworkTrafficParamsTypeTYPESTART: + return true + case GetApiEventsNetworkTrafficParamsTypeTYPEUNKNOWN: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsConnectionType. const ( GetApiEventsNetworkTrafficParamsConnectionTypeP2P GetApiEventsNetworkTrafficParamsConnectionType = "P2P" GetApiEventsNetworkTrafficParamsConnectionTypeROUTED GetApiEventsNetworkTrafficParamsConnectionType = "ROUTED" ) +// Valid indicates whether the value is a known member of the GetApiEventsNetworkTrafficParamsConnectionType enum. +func (e GetApiEventsNetworkTrafficParamsConnectionType) Valid() bool { + switch e { + case GetApiEventsNetworkTrafficParamsConnectionTypeP2P: + return true + case GetApiEventsNetworkTrafficParamsConnectionTypeROUTED: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsDirection. const ( GetApiEventsNetworkTrafficParamsDirectionDIRECTIONUNKNOWN GetApiEventsNetworkTrafficParamsDirection = "DIRECTION_UNKNOWN" @@ -367,6 +1188,83 @@ const ( GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS" ) +// Valid indicates whether the value is a known member of the GetApiEventsNetworkTrafficParamsDirection enum. +func (e GetApiEventsNetworkTrafficParamsDirection) Valid() bool { + switch e { + case GetApiEventsNetworkTrafficParamsDirectionDIRECTIONUNKNOWN: + return true + case GetApiEventsNetworkTrafficParamsDirectionEGRESS: + return true + case GetApiEventsNetworkTrafficParamsDirectionINGRESS: + return true + default: + return false + } +} + +// Defines values for GetApiEventsProxyParamsSortBy. +const ( + GetApiEventsProxyParamsSortByAuthMethod GetApiEventsProxyParamsSortBy = "auth_method" + GetApiEventsProxyParamsSortByDuration GetApiEventsProxyParamsSortBy = "duration" + GetApiEventsProxyParamsSortByHost GetApiEventsProxyParamsSortBy = "host" + GetApiEventsProxyParamsSortByMethod GetApiEventsProxyParamsSortBy = "method" + GetApiEventsProxyParamsSortByPath GetApiEventsProxyParamsSortBy = "path" + GetApiEventsProxyParamsSortByReason GetApiEventsProxyParamsSortBy = "reason" + GetApiEventsProxyParamsSortBySourceIp GetApiEventsProxyParamsSortBy = "source_ip" + GetApiEventsProxyParamsSortByStatusCode GetApiEventsProxyParamsSortBy = "status_code" + GetApiEventsProxyParamsSortByTimestamp GetApiEventsProxyParamsSortBy = "timestamp" + GetApiEventsProxyParamsSortByUrl GetApiEventsProxyParamsSortBy = "url" + GetApiEventsProxyParamsSortByUserId GetApiEventsProxyParamsSortBy = "user_id" +) + +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsSortBy enum. +func (e GetApiEventsProxyParamsSortBy) Valid() bool { + switch e { + case GetApiEventsProxyParamsSortByAuthMethod: + return true + case GetApiEventsProxyParamsSortByDuration: + return true + case GetApiEventsProxyParamsSortByHost: + return true + case GetApiEventsProxyParamsSortByMethod: + return true + case GetApiEventsProxyParamsSortByPath: + return true + case GetApiEventsProxyParamsSortByReason: + return true + case GetApiEventsProxyParamsSortBySourceIp: + return true + case GetApiEventsProxyParamsSortByStatusCode: + return true + case GetApiEventsProxyParamsSortByTimestamp: + return true + case GetApiEventsProxyParamsSortByUrl: + return true + case GetApiEventsProxyParamsSortByUserId: + return true + default: + return false + } +} + +// Defines values for GetApiEventsProxyParamsSortOrder. +const ( + GetApiEventsProxyParamsSortOrderAsc GetApiEventsProxyParamsSortOrder = "asc" + GetApiEventsProxyParamsSortOrderDesc GetApiEventsProxyParamsSortOrder = "desc" +) + +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsSortOrder enum. +func (e GetApiEventsProxyParamsSortOrder) Valid() bool { + switch e { + case GetApiEventsProxyParamsSortOrderAsc: + return true + case GetApiEventsProxyParamsSortOrderDesc: + return true + default: + return false + } +} + // Defines values for GetApiEventsProxyParamsMethod. const ( GetApiEventsProxyParamsMethodDELETE GetApiEventsProxyParamsMethod = "DELETE" @@ -378,18 +1276,85 @@ const ( GetApiEventsProxyParamsMethodPUT GetApiEventsProxyParamsMethod = "PUT" ) +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsMethod enum. +func (e GetApiEventsProxyParamsMethod) Valid() bool { + switch e { + case GetApiEventsProxyParamsMethodDELETE: + return true + case GetApiEventsProxyParamsMethodGET: + return true + case GetApiEventsProxyParamsMethodHEAD: + return true + case GetApiEventsProxyParamsMethodOPTIONS: + return true + case GetApiEventsProxyParamsMethodPATCH: + return true + case GetApiEventsProxyParamsMethodPOST: + return true + case GetApiEventsProxyParamsMethodPUT: + return true + default: + return false + } +} + // Defines values for GetApiEventsProxyParamsStatus. const ( GetApiEventsProxyParamsStatusFailed GetApiEventsProxyParamsStatus = "failed" GetApiEventsProxyParamsStatusSuccess GetApiEventsProxyParamsStatus = "success" ) +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsStatus enum. +func (e GetApiEventsProxyParamsStatus) Valid() bool { + switch e { + case GetApiEventsProxyParamsStatusFailed: + return true + case GetApiEventsProxyParamsStatusSuccess: + return true + default: + return false + } +} + // Defines values for PutApiIntegrationsMspTenantsIdInviteJSONBodyValue. const ( PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "accept" PutApiIntegrationsMspTenantsIdInviteJSONBodyValueDecline PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "decline" ) +// Valid indicates whether the value is a known member of the PutApiIntegrationsMspTenantsIdInviteJSONBodyValue enum. +func (e PutApiIntegrationsMspTenantsIdInviteJSONBodyValue) Valid() bool { + switch e { + case PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept: + return true + case PutApiIntegrationsMspTenantsIdInviteJSONBodyValueDecline: + return true + default: + return false + } +} + +// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. +type AccessRestrictions struct { + // AllowedCidrs CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed. + AllowedCidrs *[]string `json:"allowed_cidrs,omitempty"` + + // AllowedCountries ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted. + AllowedCountries *[]string `json:"allowed_countries,omitempty"` + + // BlockedCidrs CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs. + BlockedCidrs *[]string `json:"blocked_cidrs,omitempty"` + + // BlockedCountries ISO 3166-1 alpha-2 country codes to block. + BlockedCountries *[]string `json:"blocked_countries,omitempty"` + + // CrowdsecMode CrowdSec IP reputation mode. Only available when the proxy cluster supports CrowdSec. + CrowdsecMode *AccessRestrictionsCrowdsecMode `json:"crowdsec_mode,omitempty"` +} + +// AccessRestrictionsCrowdsecMode CrowdSec IP reputation mode. Only available when the proxy cluster supports CrowdSec. +type AccessRestrictionsCrowdsecMode string + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // CityName Commonly used English name of the city @@ -481,6 +1446,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // AutoUpdateAlways When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. + AutoUpdateAlways *bool `json:"auto_update_always,omitempty"` + // AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") AutoUpdateVersion *string `json:"auto_update_version,omitempty"` @@ -512,6 +1480,12 @@ type AccountSettings struct { // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` + // PeerExposeEnabled Enables or disables peer expose. If enabled, peers can expose local services through the reverse proxy using the CLI. + PeerExposeEnabled bool `json:"peer_expose_enabled"` + + // PeerExposeGroups Limits which peer groups are allowed to expose services. If empty, all peers are allowed when peer expose is enabled. + PeerExposeGroups []string `json:"peer_expose_groups"` + // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). PeerInactivityExpiration int `json:"peer_inactivity_expiration"` @@ -540,6 +1514,39 @@ type AvailablePorts struct { Udp int `json:"udp"` } +// AzureIntegration defines model for AzureIntegration. +type AzureIntegration struct { + // ClientId Azure AD application (client) ID + ClientId string `json:"client_id"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Host Azure host domain for the Graph API + Host string `json:"host"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // SyncInterval Sync interval in seconds + SyncInterval int `json:"sync_interval"` + + // TenantId Azure AD tenant ID + TenantId string `json:"tenant_id"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // BearerAuthConfig defines model for BearerAuthConfig. type BearerAuthConfig struct { // DistributionGroups List of group IDs that can use bearer auth @@ -566,7 +1573,7 @@ type BundleParameters struct { // BundleResult defines model for BundleResult. type BundleResult struct { - UploadKey *string `json:"upload_key"` + UploadKey *string `json:"upload_key,omitempty"` } // BundleWorkloadRequest defines model for BundleWorkloadRequest. @@ -647,6 +1654,57 @@ type Country struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country type CountryCode = string +// CreateAzureIntegrationRequest defines model for CreateAzureIntegrationRequest. +type CreateAzureIntegrationRequest struct { + // ClientId Azure AD application (client) ID + ClientId string `json:"client_id"` + + // ClientSecret Base64-encoded Azure AD client secret + ClientSecret string `json:"client_secret"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // Host Azure host domain for the Graph API + Host CreateAzureIntegrationRequestHost `json:"host"` + + // SyncInterval Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + SyncInterval *int `json:"sync_interval,omitempty"` + + // TenantId Azure AD tenant ID + TenantId string `json:"tenant_id"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// CreateAzureIntegrationRequestHost Azure host domain for the Graph API +type CreateAzureIntegrationRequestHost string + +// CreateGoogleIntegrationRequest defines model for CreateGoogleIntegrationRequest. +type CreateGoogleIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // CustomerId Customer ID from Google Workspace Account Settings + CustomerId string `json:"customer_id"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // ServiceAccountKey Base64-encoded Google service account key + ServiceAccountKey string `json:"service_account_key"` + + // SyncInterval Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + SyncInterval *int `json:"sync_interval,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // CreateIntegrationRequest Request payload for creating a new event streaming integration. Also used as the structure for the PUT request body, but not all fields are applicable for updates (see PUT operation description). type CreateIntegrationRequest struct { // Config Platform-specific configuration as key-value pairs. For creation, all necessary credentials and settings must be provided. For updates, provide the fields to change or the entire new configuration. @@ -662,8 +1720,26 @@ type CreateIntegrationRequest struct { // CreateIntegrationRequestPlatform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend. type CreateIntegrationRequestPlatform string -// CreateScimIntegrationRequest Request payload for creating an SCIM IDP integration +// CreateOktaScimIntegrationRequest defines model for CreateOktaScimIntegrationRequest. +type CreateOktaScimIntegrationRequest struct { + // ConnectionName The Okta enterprise connection name on Auth0 + ConnectionName string `json:"connection_name"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// CreateScimIntegrationRequest defines model for CreateScimIntegrationRequest. type CreateScimIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + // GroupPrefixes List of start_with string patterns for groups to sync GroupPrefixes *[]string `json:"group_prefixes,omitempty"` @@ -815,6 +1891,63 @@ type EDRFalconResponse struct { ZtaScoreThreshold int `json:"zta_score_threshold"` } +// EDRFleetDMRequest Request payload for creating or updating a FleetDM EDR integration +type EDRFleetDMRequest struct { + // ApiToken FleetDM API token + ApiToken string `json:"api_token"` + + // ApiUrl FleetDM server URL + ApiUrl string `json:"api_url"` + + // Enabled Indicates whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // Groups The Groups this integrations applies to + Groups []string `json:"groups"` + + // LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours + LastSyncedInterval int `json:"last_synced_interval"` + + // MatchAttributes Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly + MatchAttributes FleetDMMatchAttributes `json:"match_attributes"` +} + +// EDRFleetDMResponse Represents a FleetDM EDR integration configuration +type EDRFleetDMResponse struct { + // AccountId The identifier of the account this integration belongs to. + AccountId string `json:"account_id"` + + // ApiUrl FleetDM server URL + ApiUrl string `json:"api_url"` + + // CreatedAt Timestamp of when the integration was created. + CreatedAt time.Time `json:"created_at"` + + // CreatedBy The user id that created the integration + CreatedBy string `json:"created_by"` + + // Enabled Indicates whether the integration is enabled + Enabled bool `json:"enabled"` + + // Groups List of groups + Groups []Group `json:"groups"` + + // Id The unique numeric identifier for the integration. + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of when the integration was last synced. + LastSyncedAt time.Time `json:"last_synced_at"` + + // LastSyncedInterval The devices last sync requirement interval in hours. + LastSyncedInterval int `json:"last_synced_interval"` + + // MatchAttributes Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly + MatchAttributes FleetDMMatchAttributes `json:"match_attributes"` + + // UpdatedAt Timestamp of when the integration was last updated. + UpdatedAt time.Time `json:"updated_at"` +} + // EDRHuntressRequest Request payload for creating or updating a EDR Huntress integration type EDRHuntressRequest struct { // ApiKey Huntress API key @@ -983,6 +2116,12 @@ type EDRSentinelOneResponse struct { UpdatedAt time.Time `json:"updated_at"` } +// EmailTarget Target configuration for email notification channels. +type EmailTarget struct { + // Emails List of email addresses to send notifications to. + Emails []openapi_types.Email `json:"emails"` +} + // ErrorResponse Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. type ErrorResponse struct { // Message A human-readable error message. @@ -1022,6 +2161,24 @@ type Event struct { // EventActivityCode The string code of the activity that occurred during the event type EventActivityCode string +// FleetDMMatchAttributes Attribute conditions to match when approving FleetDM hosts. Most attributes work with FleetDM's free/open-source version. Premium-only attributes are marked accordingly +type FleetDMMatchAttributes struct { + // DiskEncryptionEnabled Whether disk encryption (FileVault/BitLocker) must be enabled on the host + DiskEncryptionEnabled *bool `json:"disk_encryption_enabled,omitempty"` + + // FailingPoliciesCountMax Maximum number of allowed failing policies. Use 0 to require all policies to pass + FailingPoliciesCountMax *int `json:"failing_policies_count_max,omitempty"` + + // RequiredPolicies List of FleetDM policy IDs that must be passing on the host. If any of these policies is failing, the host is non-compliant + RequiredPolicies *[]int `json:"required_policies,omitempty"` + + // StatusOnline Whether the host must be online (recently seen by Fleet) + StatusOnline *bool `json:"status_online,omitempty"` + + // VulnerableSoftwareCountMax Maximum number of allowed vulnerable software on the host + VulnerableSoftwareCountMax *int `json:"vulnerable_software_count_max,omitempty"` +} + // GeoLocationCheck Posture check for geo location type GeoLocationCheck struct { // Action Action to take upon policy match @@ -1037,6 +2194,33 @@ type GeoLocationCheckAction string // GetTenantsResponse defines model for GetTenantsResponse. type GetTenantsResponse = []TenantResponse +// GoogleIntegration defines model for GoogleIntegration. +type GoogleIntegration struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // CustomerId Customer ID from Google Workspace + CustomerId string `json:"customer_id"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // SyncInterval Sync interval in seconds + SyncInterval int `json:"sync_interval"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // Group defines model for Group. type Group struct { // Id Group ID @@ -1093,6 +2277,18 @@ type GroupRequest struct { Resources *[]Resource `json:"resources,omitempty"` } +// HeaderAuthConfig Static header-value authentication. The proxy checks that the named header matches the configured value. +type HeaderAuthConfig struct { + // Enabled Whether header auth is enabled + Enabled bool `json:"enabled"` + + // Header HTTP header name to check (e.g. "Authorization", "X-API-Key") + Header string `json:"header"` + + // Value Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses. + Value string `json:"value"` +} + // HuntressMatchAttributes Attribute conditions to match when approving agents type HuntressMatchAttributes struct { // DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus. @@ -1316,6 +2512,12 @@ type InstanceVersionInfo struct { ManagementUpdateAvailable bool `json:"management_update_available"` } +// IntegrationEnabled defines model for IntegrationEnabled. +type IntegrationEnabled struct { + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` +} + // IntegrationResponse Represents an event streaming integration. type IntegrationResponse struct { // AccountId The identifier of the account this integration belongs to. @@ -1343,6 +2545,18 @@ type IntegrationResponse struct { // IntegrationResponsePlatform The event streaming platform. type IntegrationResponsePlatform string +// IntegrationSyncFilters defines model for IntegrationSyncFilters. +type IntegrationSyncFilters struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // InvoicePDFResponse defines model for InvoicePDFResponse. type InvoicePDFResponse struct { // Url URL to redirect the user to invoice. @@ -1374,9 +2588,9 @@ type JobRequest struct { // JobResponse defines model for JobResponse. type JobResponse struct { - CompletedAt *time.Time `json:"completed_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` CreatedAt time.Time `json:"created_at"` - FailedReason *string `json:"failed_reason"` + FailedReason *string `json:"failed_reason,omitempty"` Id string `json:"id"` Status JobResponseStatus `json:"status"` TriggeredBy string `json:"triggered_by"` @@ -1744,6 +2958,67 @@ type NetworkTrafficUser struct { Name string `json:"name"` } +// NotificationChannelRequest Request body for creating or updating a notification channel. +type NotificationChannelRequest struct { + // Enabled Whether this notification channel is active. + Enabled bool `json:"enabled"` + + // EventTypes List of activity event type codes this channel subscribes to. + EventTypes []NotificationEventType `json:"event_types"` + + // Target Channel-specific target configuration. The shape depends on the `type` field: + // - `email`: requires an `EmailTarget` object + // - `webhook`: requires a `WebhookTarget` object + Target *NotificationChannelRequest_Target `json:"target,omitempty"` + + // Type The type of notification channel. + Type NotificationChannelType `json:"type"` +} + +// NotificationChannelRequest_Target Channel-specific target configuration. The shape depends on the `type` field: +// - `email`: requires an `EmailTarget` object +// - `webhook`: requires a `WebhookTarget` object +type NotificationChannelRequest_Target struct { + union json.RawMessage +} + +// NotificationChannelResponse A notification channel configuration. +type NotificationChannelResponse struct { + // Enabled Whether this notification channel is active. + Enabled bool `json:"enabled"` + + // EventTypes List of activity event type codes this channel subscribes to. + EventTypes []NotificationEventType `json:"event_types"` + + // Id Unique identifier of the notification channel. + Id *string `json:"id,omitempty"` + + // Target Channel-specific target configuration. The shape depends on the `type` field: + // - `email`: an `EmailTarget` object + // - `webhook`: a `WebhookTarget` object + Target *NotificationChannelResponse_Target `json:"target,omitempty"` + + // Type The type of notification channel. + Type NotificationChannelType `json:"type"` +} + +// NotificationChannelResponse_Target Channel-specific target configuration. The shape depends on the `type` field: +// - `email`: an `EmailTarget` object +// - `webhook`: a `WebhookTarget` object +type NotificationChannelResponse_Target struct { + union json.RawMessage +} + +// NotificationChannelType The type of notification channel. +type NotificationChannelType string + +// NotificationEventType An activity event type code. See `GET /api/integrations/notifications/types` for the full list +// of supported event types and their human-readable descriptions. +type NotificationEventType = string + +// NotificationTypeEntry A map of event type codes to their human-readable descriptions. +type NotificationTypeEntry map[string]string + // OSVersionCheck Posture check for the version of operating system type OSVersionCheck struct { // Android Posture check for the version of operating system @@ -1762,6 +3037,30 @@ type OSVersionCheck struct { Windows *MinKernelVersionCheck `json:"windows,omitempty"` } +// OktaScimIntegration defines model for OktaScimIntegration. +type OktaScimIntegration struct { + // AuthToken SCIM API token (full on creation/regeneration, masked on retrieval) + AuthToken string `json:"auth_token"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // PINAuthConfig defines model for PINAuthConfig. type PINAuthConfig struct { // Enabled Whether PIN auth is enabled @@ -2387,6 +3686,12 @@ type ProxyAccessLog struct { // AuthMethodUsed Authentication method used (e.g., password, pin, oidc) AuthMethodUsed *string `json:"auth_method_used,omitempty"` + // BytesDownload Bytes downloaded (response body size) + BytesDownload int64 `json:"bytes_download"` + + // BytesUpload Bytes uploaded (request body size) + BytesUpload int64 `json:"bytes_upload"` + // CityName City name from geolocation CityName *string `json:"city_name,omitempty"` @@ -2402,12 +3707,18 @@ type ProxyAccessLog struct { // Id Unique identifier for the access log entry Id string `json:"id"` + // Metadata Extra context about the request (e.g. crowdsec_verdict) + Metadata *map[string]string `json:"metadata,omitempty"` + // Method HTTP method of the request Method string `json:"method"` // Path Path of the request Path string `json:"path"` + // Protocol Protocol type: http, tcp, or udp + Protocol *string `json:"protocol,omitempty"` + // Reason Reason for the request result (e.g., authentication failure) Reason *string `json:"reason,omitempty"` @@ -2420,6 +3731,9 @@ type ProxyAccessLog struct { // StatusCode HTTP status code returned StatusCode int `json:"status_code"` + // SubdivisionCode First-level administrative subdivision ISO code (e.g. state/province) + SubdivisionCode *string `json:"subdivision_code,omitempty"` + // Timestamp Timestamp when the request was made Timestamp time.Time `json:"timestamp"` @@ -2472,6 +3786,15 @@ type ReverseProxyDomain struct { // Id Domain ID Id string `json:"id"` + // RequireSubdomain Whether a subdomain label is required in front of this domain. When true, the domain cannot be used bare. + RequireSubdomain *bool `json:"require_subdomain,omitempty"` + + // SupportsCrowdsec Whether the proxy cluster has CrowdSec configured + SupportsCrowdsec *bool `json:"supports_crowdsec,omitempty"` + + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports + SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` + // TargetCluster The proxy cluster this domain is validated against (only for custom domains) TargetCluster *string `json:"target_cluster,omitempty"` @@ -2593,12 +3916,15 @@ type RulePortRange struct { Start int `json:"start"` } -// ScimIntegration Represents a SCIM IDP integration +// ScimIntegration defines model for ScimIntegration. type ScimIntegration struct { // AuthToken SCIM API token (full on creation, masked otherwise) AuthToken string `json:"auth_token"` - // Enabled Indicates whether the integration is enabled + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled Enabled bool `json:"enabled"` // GroupPrefixes List of start_with string patterns for groups to sync @@ -2610,6 +3936,9 @@ type ScimIntegration struct { // LastSyncedAt Timestamp of when the integration was last synced LastSyncedAt time.Time `json:"last_synced_at"` + // Prefix The connection prefix used for the SCIM provider + Prefix string `json:"prefix"` + // Provider Name of the SCIM identity provider Provider string `json:"provider"` @@ -2655,7 +3984,9 @@ type SentinelOneMatchAttributesNetworkStatus string // Service defines model for Service. type Service struct { - Auth ServiceAuthConfig `json:"auth"` + // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. + AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"` + Auth ServiceAuthConfig `json:"auth"` // Domain Domain for the service Domain string `json:"domain"` @@ -2664,8 +3995,14 @@ type Service struct { Enabled bool `json:"enabled"` // Id Service ID - Id string `json:"id"` - Meta ServiceMeta `json:"meta"` + Id string `json:"id"` + + // ListenPort Port the proxy listens on (L4/TLS only) + ListenPort *int `json:"listen_port,omitempty"` + Meta ServiceMeta `json:"meta"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceMode `json:"mode,omitempty"` // Name Service name Name string `json:"name"` @@ -2673,6 +4010,9 @@ type Service struct { // PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address PassHostHeader *bool `json:"pass_host_header,omitempty"` + // PortAutoAssigned Whether the listen port was auto-assigned + PortAutoAssigned *bool `json:"port_auto_assigned,omitempty"` + // ProxyCluster The proxy cluster handling this service (derived from domain) ProxyCluster *string `json:"proxy_cluster,omitempty"` @@ -2681,11 +4021,18 @@ type Service struct { // Targets List of target backends for this service Targets []ServiceTarget `json:"targets"` + + // Terminated Whether the service has been terminated. Terminated services cannot be updated. Services that violate the Terms of Service will be terminated. + Terminated *bool `json:"terminated,omitempty"` } +// ServiceMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceMode string + // ServiceAuthConfig defines model for ServiceAuthConfig. type ServiceAuthConfig struct { BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"` + HeaderAuths *[]HeaderAuthConfig `json:"header_auths,omitempty"` LinkAuth *LinkAuthConfig `json:"link_auth,omitempty"` PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty"` PinAuth *PINAuthConfig `json:"pin_auth,omitempty"` @@ -2708,7 +4055,9 @@ type ServiceMetaStatus string // ServiceRequest defines model for ServiceRequest. type ServiceRequest struct { - Auth ServiceAuthConfig `json:"auth"` + // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. + AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"` + Auth *ServiceAuthConfig `json:"auth,omitempty"` // Domain Domain for the service Domain string `json:"domain"` @@ -2716,6 +4065,12 @@ type ServiceRequest struct { // Enabled Whether the service is enabled Enabled bool `json:"enabled"` + // ListenPort Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. + ListenPort *int `json:"listen_port,omitempty"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceRequestMode `json:"mode,omitempty"` + // Name Service name Name string `json:"name"` @@ -2726,21 +4081,25 @@ type ServiceRequest struct { RewriteRedirects *bool `json:"rewrite_redirects,omitempty"` // Targets List of target backends for this service - Targets []ServiceTarget `json:"targets"` + Targets *[]ServiceTarget `json:"targets,omitempty"` } +// ServiceRequestMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceRequestMode string + // ServiceTarget defines model for ServiceTarget. type ServiceTarget struct { // Enabled Whether this target is enabled Enabled bool `json:"enabled"` // Host Backend ip or domain for this target - Host *string `json:"host,omitempty"` + Host *string `json:"host,omitempty"` + Options *ServiceTargetOptions `json:"options,omitempty"` - // Path URL path prefix for this target + // Path URL path prefix for this target (HTTP only) Path *string `json:"path,omitempty"` - // Port Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + // Port Backend port for this target Port int `json:"port"` // Protocol Protocol to use when connecting to the backend @@ -2749,16 +4108,40 @@ type ServiceTarget struct { // TargetId Target ID TargetId string `json:"target_id"` - // TargetType Target type (e.g., "peer", "resource") + // TargetType Target type TargetType ServiceTargetTargetType `json:"target_type"` } // ServiceTargetProtocol Protocol to use when connecting to the backend type ServiceTargetProtocol string -// ServiceTargetTargetType Target type (e.g., "peer", "resource") +// ServiceTargetTargetType Target type type ServiceTargetTargetType string +// ServiceTargetOptions defines model for ServiceTargetOptions. +type ServiceTargetOptions struct { + // CustomHeaders Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected. + CustomHeaders *map[string]string `json:"custom_headers,omitempty"` + + // PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. + PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"` + + // ProxyProtocol Send PROXY Protocol v2 header to this backend (TCP/TLS only) + ProxyProtocol *bool `json:"proxy_protocol,omitempty"` + + // RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m") + RequestTimeout *string `json:"request_timeout,omitempty"` + + // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). + SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"` + + // SkipTlsVerify Skip TLS certificate verification for this backend + SkipTlsVerify *bool `json:"skip_tls_verify,omitempty"` +} + +// ServiceTargetOptionsPathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. +type ServiceTargetOptionsPathRewrite string + // SetupKey defines model for SetupKey. type SetupKey struct { // AllowExtraDnsLabels Allow extra DNS labels to be added to the peer @@ -2960,6 +4343,11 @@ type Subscription struct { UpdatedAt time.Time `json:"updated_at"` } +// SyncResult Response for a manual sync trigger +type SyncResult struct { + Result *string `json:"result,omitempty"` +} + // TenantGroupResponse defines model for TenantGroupResponse. type TenantGroupResponse struct { // Id The Group ID @@ -3005,14 +4393,86 @@ type TenantResponse struct { // TenantResponseStatus The status of the tenant type TenantResponseStatus string -// UpdateScimIntegrationRequest Request payload for updating an SCIM IDP integration -type UpdateScimIntegrationRequest struct { - // Enabled Indicates whether the integration is enabled +// UpdateAzureIntegrationRequest defines model for UpdateAzureIntegrationRequest. +type UpdateAzureIntegrationRequest struct { + // ClientId Azure AD application (client) ID + ClientId *string `json:"client_id,omitempty"` + + // ClientSecret Base64-encoded Azure AD client secret + ClientSecret *string `json:"client_secret,omitempty"` + + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled Enabled *bool `json:"enabled,omitempty"` // GroupPrefixes List of start_with string patterns for groups to sync GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + // SyncInterval Sync interval in seconds (minimum 300) + SyncInterval *int `json:"sync_interval,omitempty"` + + // TenantId Azure AD tenant ID + TenantId *string `json:"tenant_id,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateGoogleIntegrationRequest defines model for UpdateGoogleIntegrationRequest. +type UpdateGoogleIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // CustomerId Customer ID from Google Workspace Account Settings + CustomerId *string `json:"customer_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // ServiceAccountKey Base64-encoded Google service account key + ServiceAccountKey *string `json:"service_account_key,omitempty"` + + // SyncInterval Sync interval in seconds (minimum 300) + SyncInterval *int `json:"sync_interval,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateOktaScimIntegrationRequest defines model for UpdateOktaScimIntegrationRequest. +type UpdateOktaScimIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateScimIntegrationRequest defines model for UpdateScimIntegrationRequest. +type UpdateScimIntegrationRequest struct { + // ConnectorId DEX connector ID for embedded IDP setups + ConnectorId *string `json:"connector_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // Prefix The connection prefix used for the SCIM provider + Prefix *string `json:"prefix,omitempty"` + // UserGroupPrefixes List of start_with string patterns for groups which users to sync UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` } @@ -3220,6 +4680,16 @@ type UserRequest struct { Role string `json:"role"` } +// WebhookTarget Target configuration for webhook notification channels. +type WebhookTarget struct { + // Headers Custom HTTP headers sent with each webhook request. + // Values are write-only; in GET responses all values are masked. + Headers *map[string]string `json:"headers,omitempty"` + + // Url The webhook endpoint URL to send notifications to. + Url string `json:"url"` +} + // WorkloadRequest defines model for WorkloadRequest. type WorkloadRequest struct { union json.RawMessage @@ -3276,6 +4746,9 @@ type ZoneRequest struct { Name string `json:"name"` } +// Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. +type Conflict = ErrorResponse + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number @@ -3329,6 +4802,12 @@ type GetApiEventsProxyParams struct { // PageSize Number of items per page (max 100) PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"` + // SortBy Field to sort by (url sorts by host then path) + SortBy *GetApiEventsProxyParamsSortBy `form:"sort_by,omitempty" json:"sort_by,omitempty"` + + // SortOrder Sort order (ascending or descending) + SortOrder *GetApiEventsProxyParamsSortOrder `form:"sort_order,omitempty" json:"sort_order,omitempty"` + // Search General search across request ID, host, path, source IP, user email, and user name Search *string `form:"search,omitempty" json:"search,omitempty"` @@ -3366,6 +4845,12 @@ type GetApiEventsProxyParams struct { EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"` } +// GetApiEventsProxyParamsSortBy defines parameters for GetApiEventsProxy. +type GetApiEventsProxyParamsSortBy string + +// GetApiEventsProxyParamsSortOrder defines parameters for GetApiEventsProxy. +type GetApiEventsProxyParamsSortOrder string + // GetApiEventsProxyParamsMethod defines parameters for GetApiEventsProxy. type GetApiEventsProxyParamsMethod string @@ -3507,6 +4992,12 @@ type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest // PutApiIngressPeersIngressPeerIdJSONRequestBody defines body for PutApiIngressPeersIngressPeerId for application/json ContentType. type PutApiIngressPeersIngressPeerIdJSONRequestBody = IngressPeerUpdateRequest +// CreateAzureIntegrationJSONRequestBody defines body for CreateAzureIntegration for application/json ContentType. +type CreateAzureIntegrationJSONRequestBody = CreateAzureIntegrationRequest + +// UpdateAzureIntegrationJSONRequestBody defines body for UpdateAzureIntegration for application/json ContentType. +type UpdateAzureIntegrationJSONRequestBody = UpdateAzureIntegrationRequest + // PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceActivate for application/json ContentType. type PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody @@ -3525,6 +5016,12 @@ type CreateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest // UpdateFalconEDRIntegrationJSONRequestBody defines body for UpdateFalconEDRIntegration for application/json ContentType. type UpdateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest +// CreateFleetDMEDRIntegrationJSONRequestBody defines body for CreateFleetDMEDRIntegration for application/json ContentType. +type CreateFleetDMEDRIntegrationJSONRequestBody = EDRFleetDMRequest + +// UpdateFleetDMEDRIntegrationJSONRequestBody defines body for UpdateFleetDMEDRIntegration for application/json ContentType. +type UpdateFleetDMEDRIntegrationJSONRequestBody = EDRFleetDMRequest + // CreateHuntressEDRIntegrationJSONRequestBody defines body for CreateHuntressEDRIntegration for application/json ContentType. type CreateHuntressEDRIntegrationJSONRequestBody = EDRHuntressRequest @@ -3543,6 +5040,12 @@ type CreateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest // UpdateSentinelOneEDRIntegrationJSONRequestBody defines body for UpdateSentinelOneEDRIntegration for application/json ContentType. type UpdateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest +// CreateGoogleIntegrationJSONRequestBody defines body for CreateGoogleIntegration for application/json ContentType. +type CreateGoogleIntegrationJSONRequestBody = CreateGoogleIntegrationRequest + +// UpdateGoogleIntegrationJSONRequestBody defines body for UpdateGoogleIntegration for application/json ContentType. +type UpdateGoogleIntegrationJSONRequestBody = UpdateGoogleIntegrationRequest + // PostApiIntegrationsMspTenantsJSONRequestBody defines body for PostApiIntegrationsMspTenants for application/json ContentType. type PostApiIntegrationsMspTenantsJSONRequestBody = CreateTenantRequest @@ -3558,6 +5061,18 @@ type PostApiIntegrationsMspTenantsIdSubscriptionJSONRequestBody PostApiIntegrati // PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody defines body for PostApiIntegrationsMspTenantsIdUnlink for application/json ContentType. type PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody PostApiIntegrationsMspTenantsIdUnlinkJSONBody +// CreateNotificationChannelJSONRequestBody defines body for CreateNotificationChannel for application/json ContentType. +type CreateNotificationChannelJSONRequestBody = NotificationChannelRequest + +// UpdateNotificationChannelJSONRequestBody defines body for UpdateNotificationChannel for application/json ContentType. +type UpdateNotificationChannelJSONRequestBody = NotificationChannelRequest + +// CreateOktaScimIntegrationJSONRequestBody defines body for CreateOktaScimIntegration for application/json ContentType. +type CreateOktaScimIntegrationJSONRequestBody = CreateOktaScimIntegrationRequest + +// UpdateOktaScimIntegrationJSONRequestBody defines body for UpdateOktaScimIntegration for application/json ContentType. +type UpdateOktaScimIntegrationJSONRequestBody = UpdateOktaScimIntegrationRequest + // CreateSCIMIntegrationJSONRequestBody defines body for CreateSCIMIntegration for application/json ContentType. type CreateSCIMIntegrationJSONRequestBody = CreateScimIntegrationRequest @@ -3654,6 +5169,130 @@ type PutApiUsersUserIdPasswordJSONRequestBody = PasswordChangeRequest // PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType. type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest +// AsEmailTarget returns the union data inside the NotificationChannelRequest_Target as a EmailTarget +func (t NotificationChannelRequest_Target) AsEmailTarget() (EmailTarget, error) { + var body EmailTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromEmailTarget overwrites any union data inside the NotificationChannelRequest_Target as the provided EmailTarget +func (t *NotificationChannelRequest_Target) FromEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeEmailTarget performs a merge with any union data inside the NotificationChannelRequest_Target, using the provided EmailTarget +func (t *NotificationChannelRequest_Target) MergeEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsWebhookTarget returns the union data inside the NotificationChannelRequest_Target as a WebhookTarget +func (t NotificationChannelRequest_Target) AsWebhookTarget() (WebhookTarget, error) { + var body WebhookTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWebhookTarget overwrites any union data inside the NotificationChannelRequest_Target as the provided WebhookTarget +func (t *NotificationChannelRequest_Target) FromWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWebhookTarget performs a merge with any union data inside the NotificationChannelRequest_Target, using the provided WebhookTarget +func (t *NotificationChannelRequest_Target) MergeWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t NotificationChannelRequest_Target) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *NotificationChannelRequest_Target) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// AsEmailTarget returns the union data inside the NotificationChannelResponse_Target as a EmailTarget +func (t NotificationChannelResponse_Target) AsEmailTarget() (EmailTarget, error) { + var body EmailTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromEmailTarget overwrites any union data inside the NotificationChannelResponse_Target as the provided EmailTarget +func (t *NotificationChannelResponse_Target) FromEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeEmailTarget performs a merge with any union data inside the NotificationChannelResponse_Target, using the provided EmailTarget +func (t *NotificationChannelResponse_Target) MergeEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsWebhookTarget returns the union data inside the NotificationChannelResponse_Target as a WebhookTarget +func (t NotificationChannelResponse_Target) AsWebhookTarget() (WebhookTarget, error) { + var body WebhookTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWebhookTarget overwrites any union data inside the NotificationChannelResponse_Target as the provided WebhookTarget +func (t *NotificationChannelResponse_Target) FromWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWebhookTarget performs a merge with any union data inside the NotificationChannelResponse_Target, using the provided WebhookTarget +func (t *NotificationChannelResponse_Target) MergeWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t NotificationChannelResponse_Target) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *NotificationChannelResponse_Target) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + // AsBundleWorkloadRequest returns the union data inside the WorkloadRequest as a BundleWorkloadRequest func (t WorkloadRequest) AsBundleWorkloadRequest() (BundleWorkloadRequest, error) { var body BundleWorkloadRequest diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 44838fc16..604f9c793 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v7.34.1 // source: management.proto package proto @@ -221,6 +221,61 @@ func (RuleAction) EnumDescriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{3} } +type ExposeProtocol int32 + +const ( + ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0 + ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 + ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 + ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 +) + +// Enum value maps for ExposeProtocol. +var ( + ExposeProtocol_name = map[int32]string{ + 0: "EXPOSE_HTTP", + 1: "EXPOSE_HTTPS", + 2: "EXPOSE_TCP", + 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", + } + ExposeProtocol_value = map[string]int32{ + "EXPOSE_HTTP": 0, + "EXPOSE_HTTPS": 1, + "EXPOSE_TCP": 2, + "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, + } +) + +func (x ExposeProtocol) Enum() *ExposeProtocol { + p := new(ExposeProtocol) + *p = x + return p +} + +func (x ExposeProtocol) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[4].Descriptor() +} + +func (ExposeProtocol) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[4] +} + +func (x ExposeProtocol) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use ExposeProtocol.Descriptor instead. +func (ExposeProtocol) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{4} +} + type HostConfig_Protocol int32 const ( @@ -260,11 +315,11 @@ func (x HostConfig_Protocol) String() string { } func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[4].Descriptor() + return file_management_proto_enumTypes[5].Descriptor() } func (HostConfig_Protocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[4] + return &file_management_proto_enumTypes[5] } func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { @@ -303,11 +358,11 @@ func (x DeviceAuthorizationFlowProvider) String() string { } func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[5].Descriptor() + return file_management_proto_enumTypes[6].Descriptor() } func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[5] + return &file_management_proto_enumTypes[6] } func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { @@ -2204,8 +2259,8 @@ type AutoUpdateSettings struct { unknownFields protoimpl.UnknownFields Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` - // alwaysUpdate = true → Updates happen automatically in the background - // alwaysUpdate = false → Updates only happen when triggered by a peer connection + // alwaysUpdate = true → Updates are installed automatically in the background + // alwaysUpdate = false → Updates require user interaction from the UI AlwaysUpdate bool `protobuf:"varint,2,opt,name=alwaysUpdate,proto3" json:"alwaysUpdate,omitempty"` } @@ -2873,7 +2928,9 @@ type ProviderConfig struct { // An IDP application client id ClientID string `protobuf:"bytes,1,opt,name=ClientID,proto3" json:"ClientID,omitempty"` - // An IDP application client secret + // Deprecated: use embedded IdP for providers that require a client secret (e.g. Google Workspace). + // + // Deprecated: Do not use. ClientSecret string `protobuf:"bytes,2,opt,name=ClientSecret,proto3" json:"ClientSecret,omitempty"` // An IDP API domain // Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint @@ -2937,6 +2994,7 @@ func (x *ProviderConfig) GetClientID() string { return "" } +// Deprecated: Do not use. func (x *ProviderConfig) GetClientSecret() string { if x != nil { return x.ClientSecret @@ -3983,6 +4041,350 @@ func (x *ForwardingRule) GetTranslatedPort() *PortInfo { return nil } +type ExposeServiceRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"` + Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=management.ExposeProtocol" json:"protocol,omitempty"` + Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"` + Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"` + UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` + Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` + NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` +} + +func (x *ExposeServiceRequest) Reset() { + *x = ExposeServiceRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[47] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ExposeServiceRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExposeServiceRequest) ProtoMessage() {} + +func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[47] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. +func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{47} +} + +func (x *ExposeServiceRequest) GetPort() uint32 { + if x != nil { + return x.Port + } + return 0 +} + +func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol { + if x != nil { + return x.Protocol + } + return ExposeProtocol_EXPOSE_HTTP +} + +func (x *ExposeServiceRequest) GetPin() string { + if x != nil { + return x.Pin + } + return "" +} + +func (x *ExposeServiceRequest) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +func (x *ExposeServiceRequest) GetUserGroups() []string { + if x != nil { + return x.UserGroups + } + return nil +} + +func (x *ExposeServiceRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *ExposeServiceRequest) GetNamePrefix() string { + if x != nil { + return x.NamePrefix + } + return "" +} + +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + +type ExposeServiceResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` +} + +func (x *ExposeServiceResponse) Reset() { + *x = ExposeServiceResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[48] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ExposeServiceResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExposeServiceResponse) ProtoMessage() {} + +func (x *ExposeServiceResponse) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[48] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExposeServiceResponse.ProtoReflect.Descriptor instead. +func (*ExposeServiceResponse) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{48} +} + +func (x *ExposeServiceResponse) GetServiceName() string { + if x != nil { + return x.ServiceName + } + return "" +} + +func (x *ExposeServiceResponse) GetServiceUrl() string { + if x != nil { + return x.ServiceUrl + } + return "" +} + +func (x *ExposeServiceResponse) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *ExposeServiceResponse) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + +type RenewExposeRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` +} + +func (x *RenewExposeRequest) Reset() { + *x = RenewExposeRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[49] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RenewExposeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenewExposeRequest) ProtoMessage() {} + +func (x *RenewExposeRequest) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[49] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenewExposeRequest.ProtoReflect.Descriptor instead. +func (*RenewExposeRequest) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{49} +} + +func (x *RenewExposeRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +type RenewExposeResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *RenewExposeResponse) Reset() { + *x = RenewExposeResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[50] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RenewExposeResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenewExposeResponse) ProtoMessage() {} + +func (x *RenewExposeResponse) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[50] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenewExposeResponse.ProtoReflect.Descriptor instead. +func (*RenewExposeResponse) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{50} +} + +type StopExposeRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` +} + +func (x *StopExposeRequest) Reset() { + *x = StopExposeRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[51] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StopExposeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopExposeRequest) ProtoMessage() {} + +func (x *StopExposeRequest) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[51] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopExposeRequest.ProtoReflect.Descriptor instead. +func (*StopExposeRequest) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{51} +} + +func (x *StopExposeRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +type StopExposeResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *StopExposeResponse) Reset() { + *x = StopExposeResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[52] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StopExposeResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopExposeResponse) ProtoMessage() {} + +func (x *StopExposeResponse) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[52] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopExposeResponse.ProtoReflect.Descriptor instead. +func (*StopExposeResponse) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{52} +} + type PortInfo_Range struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3995,7 +4397,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[48] + mi := &file_management_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4008,7 +4410,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[48] + mi := &file_management_proto_msgTypes[54] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4448,230 +4850,287 @@ var file_management_proto_rawDesc = []byte{ 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x69, 0x67, 0x22, 0xbc, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, - 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, - 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, - 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, - 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, - 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, - 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, - 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, - 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, - 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, - 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, - 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, - 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, - 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, - 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, - 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, - 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, - 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, - 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, - 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, - 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x4e, - 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x22, - 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, - 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, - 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, - 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, - 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, - 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, - 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, - 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, - 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, - 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, - 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, - 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, - 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, - 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, - 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, - 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, - 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, - 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, - 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, - 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, - 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, - 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, - 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, - 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, - 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, - 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, - 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, - 0x74, 0x2a, 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, - 0x0a, 0x0e, 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, 0x65, 0x64, 0x10, - 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, 0x2a, 0x4c, 0x0a, - 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, - 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, - 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, - 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, - 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, - 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, - 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, - 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, - 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, - 0x01, 0x32, 0x96, 0x05, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x44, 0x12, 0x26, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0c, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, + 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, + 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, + 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, + 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, + 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, + 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, + 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, + 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, + 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, + 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, + 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, + 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, + 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, + 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, + 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, + 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, + 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, + 0x69, 0x76, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, + 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, + 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, + 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, + 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, + 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, + 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, + 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, + 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, + 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, + 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, + 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, + 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, + 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, + 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, + 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, + 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, + 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, + 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, + 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, + 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, + 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, + 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, + 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, + 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, + 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, + 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, + 0x64, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, + 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, + 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x72, + 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x75, + 0x73, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, + 0x69, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, + 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, + 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, + 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, + 0x74, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, + 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, + 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, + 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, + 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2b, 0x0a, 0x11, + 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, 0x0a, 0x12, 0x53, 0x74, 0x6f, + 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, + 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x0e, + 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x10, 0x00, + 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, 0x65, 0x64, 0x10, 0x01, 0x12, + 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, + 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, + 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, + 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, + 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, + 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, + 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, + 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, + 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, + 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, + 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, + 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, + 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, + 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, + 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, - 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, - 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, - 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, + 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, + 0x74, 0x79, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x4c, 0x0a, + 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, - 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, - 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0b, 0x52, + 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, 0x74, 0x6f, 0x70, + 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -4686,152 +5145,166 @@ func file_management_proto_rawDescGZIP() []byte { return file_management_proto_rawDescData } -var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 6) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 49) +var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 7) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 55) var file_management_proto_goTypes = []interface{}{ (JobStatus)(0), // 0: management.JobStatus (RuleProtocol)(0), // 1: management.RuleProtocol (RuleDirection)(0), // 2: management.RuleDirection (RuleAction)(0), // 3: management.RuleAction - (HostConfig_Protocol)(0), // 4: management.HostConfig.Protocol - (DeviceAuthorizationFlowProvider)(0), // 5: management.DeviceAuthorizationFlow.provider - (*EncryptedMessage)(nil), // 6: management.EncryptedMessage - (*JobRequest)(nil), // 7: management.JobRequest - (*JobResponse)(nil), // 8: management.JobResponse - (*BundleParameters)(nil), // 9: management.BundleParameters - (*BundleResult)(nil), // 10: management.BundleResult - (*SyncRequest)(nil), // 11: management.SyncRequest - (*SyncResponse)(nil), // 12: management.SyncResponse - (*SyncMetaRequest)(nil), // 13: management.SyncMetaRequest - (*LoginRequest)(nil), // 14: management.LoginRequest - (*PeerKeys)(nil), // 15: management.PeerKeys - (*Environment)(nil), // 16: management.Environment - (*File)(nil), // 17: management.File - (*Flags)(nil), // 18: management.Flags - (*PeerSystemMeta)(nil), // 19: management.PeerSystemMeta - (*LoginResponse)(nil), // 20: management.LoginResponse - (*ServerKeyResponse)(nil), // 21: management.ServerKeyResponse - (*Empty)(nil), // 22: management.Empty - (*NetbirdConfig)(nil), // 23: management.NetbirdConfig - (*HostConfig)(nil), // 24: management.HostConfig - (*RelayConfig)(nil), // 25: management.RelayConfig - (*FlowConfig)(nil), // 26: management.FlowConfig - (*JWTConfig)(nil), // 27: management.JWTConfig - (*ProtectedHostConfig)(nil), // 28: management.ProtectedHostConfig - (*PeerConfig)(nil), // 29: management.PeerConfig - (*AutoUpdateSettings)(nil), // 30: management.AutoUpdateSettings - (*NetworkMap)(nil), // 31: management.NetworkMap - (*SSHAuth)(nil), // 32: management.SSHAuth - (*MachineUserIndexes)(nil), // 33: management.MachineUserIndexes - (*RemotePeerConfig)(nil), // 34: management.RemotePeerConfig - (*SSHConfig)(nil), // 35: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 36: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 37: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 38: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 39: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 40: management.ProviderConfig - (*Route)(nil), // 41: management.Route - (*DNSConfig)(nil), // 42: management.DNSConfig - (*CustomZone)(nil), // 43: management.CustomZone - (*SimpleRecord)(nil), // 44: management.SimpleRecord - (*NameServerGroup)(nil), // 45: management.NameServerGroup - (*NameServer)(nil), // 46: management.NameServer - (*FirewallRule)(nil), // 47: management.FirewallRule - (*NetworkAddress)(nil), // 48: management.NetworkAddress - (*Checks)(nil), // 49: management.Checks - (*PortInfo)(nil), // 50: management.PortInfo - (*RouteFirewallRule)(nil), // 51: management.RouteFirewallRule - (*ForwardingRule)(nil), // 52: management.ForwardingRule - nil, // 53: management.SSHAuth.MachineUsersEntry - (*PortInfo_Range)(nil), // 54: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 55: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 56: google.protobuf.Duration + (ExposeProtocol)(0), // 4: management.ExposeProtocol + (HostConfig_Protocol)(0), // 5: management.HostConfig.Protocol + (DeviceAuthorizationFlowProvider)(0), // 6: management.DeviceAuthorizationFlow.provider + (*EncryptedMessage)(nil), // 7: management.EncryptedMessage + (*JobRequest)(nil), // 8: management.JobRequest + (*JobResponse)(nil), // 9: management.JobResponse + (*BundleParameters)(nil), // 10: management.BundleParameters + (*BundleResult)(nil), // 11: management.BundleResult + (*SyncRequest)(nil), // 12: management.SyncRequest + (*SyncResponse)(nil), // 13: management.SyncResponse + (*SyncMetaRequest)(nil), // 14: management.SyncMetaRequest + (*LoginRequest)(nil), // 15: management.LoginRequest + (*PeerKeys)(nil), // 16: management.PeerKeys + (*Environment)(nil), // 17: management.Environment + (*File)(nil), // 18: management.File + (*Flags)(nil), // 19: management.Flags + (*PeerSystemMeta)(nil), // 20: management.PeerSystemMeta + (*LoginResponse)(nil), // 21: management.LoginResponse + (*ServerKeyResponse)(nil), // 22: management.ServerKeyResponse + (*Empty)(nil), // 23: management.Empty + (*NetbirdConfig)(nil), // 24: management.NetbirdConfig + (*HostConfig)(nil), // 25: management.HostConfig + (*RelayConfig)(nil), // 26: management.RelayConfig + (*FlowConfig)(nil), // 27: management.FlowConfig + (*JWTConfig)(nil), // 28: management.JWTConfig + (*ProtectedHostConfig)(nil), // 29: management.ProtectedHostConfig + (*PeerConfig)(nil), // 30: management.PeerConfig + (*AutoUpdateSettings)(nil), // 31: management.AutoUpdateSettings + (*NetworkMap)(nil), // 32: management.NetworkMap + (*SSHAuth)(nil), // 33: management.SSHAuth + (*MachineUserIndexes)(nil), // 34: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 35: management.RemotePeerConfig + (*SSHConfig)(nil), // 36: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 37: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 38: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 39: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 40: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 41: management.ProviderConfig + (*Route)(nil), // 42: management.Route + (*DNSConfig)(nil), // 43: management.DNSConfig + (*CustomZone)(nil), // 44: management.CustomZone + (*SimpleRecord)(nil), // 45: management.SimpleRecord + (*NameServerGroup)(nil), // 46: management.NameServerGroup + (*NameServer)(nil), // 47: management.NameServer + (*FirewallRule)(nil), // 48: management.FirewallRule + (*NetworkAddress)(nil), // 49: management.NetworkAddress + (*Checks)(nil), // 50: management.Checks + (*PortInfo)(nil), // 51: management.PortInfo + (*RouteFirewallRule)(nil), // 52: management.RouteFirewallRule + (*ForwardingRule)(nil), // 53: management.ForwardingRule + (*ExposeServiceRequest)(nil), // 54: management.ExposeServiceRequest + (*ExposeServiceResponse)(nil), // 55: management.ExposeServiceResponse + (*RenewExposeRequest)(nil), // 56: management.RenewExposeRequest + (*RenewExposeResponse)(nil), // 57: management.RenewExposeResponse + (*StopExposeRequest)(nil), // 58: management.StopExposeRequest + (*StopExposeResponse)(nil), // 59: management.StopExposeResponse + nil, // 60: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 61: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 62: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 63: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ - 9, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters + 10, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters 0, // 1: management.JobResponse.status:type_name -> management.JobStatus - 10, // 2: management.JobResponse.bundle:type_name -> management.BundleResult - 19, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta - 23, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 29, // 5: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 34, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 31, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 49, // 8: management.SyncResponse.Checks:type_name -> management.Checks - 19, // 9: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta - 19, // 10: management.LoginRequest.meta:type_name -> management.PeerSystemMeta - 15, // 11: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 48, // 12: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress - 16, // 13: management.PeerSystemMeta.environment:type_name -> management.Environment - 17, // 14: management.PeerSystemMeta.files:type_name -> management.File - 18, // 15: management.PeerSystemMeta.flags:type_name -> management.Flags - 23, // 16: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 29, // 17: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 49, // 18: management.LoginResponse.Checks:type_name -> management.Checks - 55, // 19: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp - 24, // 20: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 28, // 21: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig - 24, // 22: management.NetbirdConfig.signal:type_name -> management.HostConfig - 25, // 23: management.NetbirdConfig.relay:type_name -> management.RelayConfig - 26, // 24: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 4, // 25: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 56, // 26: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 24, // 27: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 35, // 28: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 30, // 29: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings - 29, // 30: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 34, // 31: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 41, // 32: management.NetworkMap.Routes:type_name -> management.Route - 42, // 33: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 34, // 34: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 47, // 35: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 51, // 36: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 52, // 37: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 32, // 38: management.NetworkMap.sshAuth:type_name -> management.SSHAuth - 53, // 39: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry - 35, // 40: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 27, // 41: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 5, // 42: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 40, // 43: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 40, // 44: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 45, // 45: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 43, // 46: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 44, // 47: management.CustomZone.Records:type_name -> management.SimpleRecord - 46, // 48: management.NameServerGroup.NameServers:type_name -> management.NameServer + 11, // 2: management.JobResponse.bundle:type_name -> management.BundleResult + 20, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta + 24, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig + 30, // 5: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 35, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 32, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 50, // 8: management.SyncResponse.Checks:type_name -> management.Checks + 20, // 9: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta + 20, // 10: management.LoginRequest.meta:type_name -> management.PeerSystemMeta + 16, // 11: management.LoginRequest.peerKeys:type_name -> management.PeerKeys + 49, // 12: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 17, // 13: management.PeerSystemMeta.environment:type_name -> management.Environment + 18, // 14: management.PeerSystemMeta.files:type_name -> management.File + 19, // 15: management.PeerSystemMeta.flags:type_name -> management.Flags + 24, // 16: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig + 30, // 17: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 50, // 18: management.LoginResponse.Checks:type_name -> management.Checks + 62, // 19: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 25, // 20: management.NetbirdConfig.stuns:type_name -> management.HostConfig + 29, // 21: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 25, // 22: management.NetbirdConfig.signal:type_name -> management.HostConfig + 26, // 23: management.NetbirdConfig.relay:type_name -> management.RelayConfig + 27, // 24: management.NetbirdConfig.flow:type_name -> management.FlowConfig + 5, // 25: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 63, // 26: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 25, // 27: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 36, // 28: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 31, // 29: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 30, // 30: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 35, // 31: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 42, // 32: management.NetworkMap.Routes:type_name -> management.Route + 43, // 33: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 35, // 34: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 48, // 35: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 52, // 36: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 53, // 37: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 33, // 38: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 60, // 39: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 36, // 40: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 28, // 41: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 6, // 42: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 41, // 43: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 41, // 44: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 46, // 45: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 44, // 46: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 45, // 47: management.CustomZone.Records:type_name -> management.SimpleRecord + 47, // 48: management.NameServerGroup.NameServers:type_name -> management.NameServer 2, // 49: management.FirewallRule.Direction:type_name -> management.RuleDirection 3, // 50: management.FirewallRule.Action:type_name -> management.RuleAction 1, // 51: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 50, // 52: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 54, // 53: management.PortInfo.range:type_name -> management.PortInfo.Range + 51, // 52: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 61, // 53: management.PortInfo.range:type_name -> management.PortInfo.Range 3, // 54: management.RouteFirewallRule.action:type_name -> management.RuleAction 1, // 55: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 50, // 56: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 51, // 56: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo 1, // 57: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 50, // 58: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 50, // 59: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 33, // 60: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes - 6, // 61: management.ManagementService.Login:input_type -> management.EncryptedMessage - 6, // 62: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 22, // 63: management.ManagementService.GetServerKey:input_type -> management.Empty - 22, // 64: management.ManagementService.isHealthy:input_type -> management.Empty - 6, // 65: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 6, // 66: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 6, // 67: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 6, // 68: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 6, // 69: management.ManagementService.Job:input_type -> management.EncryptedMessage - 6, // 70: management.ManagementService.Login:output_type -> management.EncryptedMessage - 6, // 71: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 21, // 72: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 22, // 73: management.ManagementService.isHealthy:output_type -> management.Empty - 6, // 74: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 6, // 75: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 22, // 76: management.ManagementService.SyncMeta:output_type -> management.Empty - 22, // 77: management.ManagementService.Logout:output_type -> management.Empty - 6, // 78: management.ManagementService.Job:output_type -> management.EncryptedMessage - 70, // [70:79] is the sub-list for method output_type - 61, // [61:70] is the sub-list for method input_type - 61, // [61:61] is the sub-list for extension type_name - 61, // [61:61] is the sub-list for extension extendee - 0, // [0:61] is the sub-list for field type_name + 51, // 58: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 51, // 59: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 4, // 60: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol + 34, // 61: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 7, // 62: management.ManagementService.Login:input_type -> management.EncryptedMessage + 7, // 63: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 23, // 64: management.ManagementService.GetServerKey:input_type -> management.Empty + 23, // 65: management.ManagementService.isHealthy:input_type -> management.Empty + 7, // 66: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 7, // 67: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 7, // 68: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 7, // 69: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 7, // 70: management.ManagementService.Job:input_type -> management.EncryptedMessage + 7, // 71: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage + 7, // 72: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage + 7, // 73: management.ManagementService.StopExpose:input_type -> management.EncryptedMessage + 7, // 74: management.ManagementService.Login:output_type -> management.EncryptedMessage + 7, // 75: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 22, // 76: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 23, // 77: management.ManagementService.isHealthy:output_type -> management.Empty + 7, // 78: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 7, // 79: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 23, // 80: management.ManagementService.SyncMeta:output_type -> management.Empty + 23, // 81: management.ManagementService.Logout:output_type -> management.Empty + 7, // 82: management.ManagementService.Job:output_type -> management.EncryptedMessage + 7, // 83: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage + 7, // 84: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage + 7, // 85: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage + 74, // [74:86] is the sub-list for method output_type + 62, // [62:74] is the sub-list for method input_type + 62, // [62:62] is the sub-list for extension type_name + 62, // [62:62] is the sub-list for extension extendee + 0, // [0:62] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -5404,7 +5877,79 @@ func file_management_proto_init() { return nil } } + file_management_proto_msgTypes[47].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ExposeServiceRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } file_management_proto_msgTypes[48].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ExposeServiceResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[49].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RenewExposeRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[50].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RenewExposeResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[51].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StopExposeRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[52].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StopExposeResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[54].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -5432,8 +5977,8 @@ func file_management_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, - NumEnums: 6, - NumMessages: 49, + NumEnums: 7, + NumMessages: 55, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index d97d66819..70a530679 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -51,6 +51,15 @@ service ManagementService { // Executes a job on a target peer (e.g., debug bundle) rpc Job(stream EncryptedMessage) returns (stream EncryptedMessage) {} + + // CreateExpose creates a temporary reverse proxy service for a peer + rpc CreateExpose(EncryptedMessage) returns (EncryptedMessage) {} + + // RenewExpose extends the TTL of an active expose session + rpc RenewExpose(EncryptedMessage) returns (EncryptedMessage) {} + + // StopExpose terminates an active expose session + rpc StopExpose(EncryptedMessage) returns (EncryptedMessage) {} } message EncryptedMessage { @@ -331,8 +340,8 @@ message PeerConfig { message AutoUpdateSettings { string version = 1; /* - alwaysUpdate = true → Updates happen automatically in the background - alwaysUpdate = false → Updates only happen when triggered by a peer connection + alwaysUpdate = true → Updates are installed automatically in the background + alwaysUpdate = false → Updates require user interaction from the UI */ bool alwaysUpdate = 2; } @@ -455,8 +464,8 @@ message PKCEAuthorizationFlow { message ProviderConfig { // An IDP application client id string ClientID = 1; - // An IDP application client secret - string ClientSecret = 2; + // Deprecated: use embedded IdP for providers that require a client secret (e.g. Google Workspace). + string ClientSecret = 2 [deprecated = true]; // An IDP API domain // Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint string Domain = 3; @@ -637,3 +646,41 @@ message ForwardingRule { // Translated port information, where the traffic should be forwarded to PortInfo translatedPort = 4; } + +enum ExposeProtocol { + EXPOSE_HTTP = 0; + EXPOSE_HTTPS = 1; + EXPOSE_TCP = 2; + EXPOSE_UDP = 3; + EXPOSE_TLS = 4; +} + +message ExposeServiceRequest { + uint32 port = 1; + ExposeProtocol protocol = 2; + string pin = 3; + string password = 4; + repeated string user_groups = 5; + string domain = 6; + string name_prefix = 7; + uint32 listen_port = 8; +} + +message ExposeServiceResponse { + string service_name = 1; + string service_url = 2; + string domain = 3; + bool port_auto_assigned = 4; +} + +message RenewExposeRequest { + string domain = 1; +} + +message RenewExposeResponse {} + +message StopExposeRequest { + string domain = 1; +} + +message StopExposeResponse {} diff --git a/shared/management/proto/management_grpc.pb.go b/shared/management/proto/management_grpc.pb.go index b78e21aaa..39a342041 100644 --- a/shared/management/proto/management_grpc.pb.go +++ b/shared/management/proto/management_grpc.pb.go @@ -52,6 +52,12 @@ type ManagementServiceClient interface { Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) // Executes a job on a target peer (e.g., debug bundle) Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error) + // CreateExpose creates a temporary reverse proxy service for a peer + CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) + // RenewExpose extends the TTL of an active expose session + RenewExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) + // StopExpose terminates an active expose session + StopExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) } type managementServiceClient struct { @@ -188,6 +194,33 @@ func (x *managementServiceJobClient) Recv() (*EncryptedMessage, error) { return m, nil } +func (c *managementServiceClient) CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) { + out := new(EncryptedMessage) + err := c.cc.Invoke(ctx, "/management.ManagementService/CreateExpose", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *managementServiceClient) RenewExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) { + out := new(EncryptedMessage) + err := c.cc.Invoke(ctx, "/management.ManagementService/RenewExpose", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *managementServiceClient) StopExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) { + out := new(EncryptedMessage) + err := c.cc.Invoke(ctx, "/management.ManagementService/StopExpose", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ManagementServiceServer is the server API for ManagementService service. // All implementations must embed UnimplementedManagementServiceServer // for forward compatibility @@ -226,6 +259,12 @@ type ManagementServiceServer interface { Logout(context.Context, *EncryptedMessage) (*Empty, error) // Executes a job on a target peer (e.g., debug bundle) Job(ManagementService_JobServer) error + // CreateExpose creates a temporary reverse proxy service for a peer + CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) + // RenewExpose extends the TTL of an active expose session + RenewExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) + // StopExpose terminates an active expose session + StopExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) mustEmbedUnimplementedManagementServiceServer() } @@ -260,6 +299,15 @@ func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMe func (UnimplementedManagementServiceServer) Job(ManagementService_JobServer) error { return status.Errorf(codes.Unimplemented, "method Job not implemented") } +func (UnimplementedManagementServiceServer) CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateExpose not implemented") +} +func (UnimplementedManagementServiceServer) RenewExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) { + return nil, status.Errorf(codes.Unimplemented, "method RenewExpose not implemented") +} +func (UnimplementedManagementServiceServer) StopExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) { + return nil, status.Errorf(codes.Unimplemented, "method StopExpose not implemented") +} func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {} // UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service. @@ -446,6 +494,60 @@ func (x *managementServiceJobServer) Recv() (*EncryptedMessage, error) { return m, nil } +func _ManagementService_CreateExpose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EncryptedMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ManagementServiceServer).CreateExpose(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ManagementService/CreateExpose", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ManagementServiceServer).CreateExpose(ctx, req.(*EncryptedMessage)) + } + return interceptor(ctx, in, info, handler) +} + +func _ManagementService_RenewExpose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EncryptedMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ManagementServiceServer).RenewExpose(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ManagementService/RenewExpose", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ManagementServiceServer).RenewExpose(ctx, req.(*EncryptedMessage)) + } + return interceptor(ctx, in, info, handler) +} + +func _ManagementService_StopExpose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EncryptedMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ManagementServiceServer).StopExpose(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ManagementService/StopExpose", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ManagementServiceServer).StopExpose(ctx, req.(*EncryptedMessage)) + } + return interceptor(ctx, in, info, handler) +} + // ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -481,6 +583,18 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Logout", Handler: _ManagementService_Logout_Handler, }, + { + MethodName: "CreateExpose", + Handler: _ManagementService_CreateExpose_Handler, + }, + { + MethodName: "RenewExpose", + Handler: _ManagementService_RenewExpose_Handler, + }, + { + MethodName: "StopExpose", + Handler: _ManagementService_StopExpose_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 13fcb159e..1095b6411 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v7.34.1 // source: proxy_service.proto package proto @@ -9,6 +9,7 @@ package proto import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" @@ -70,6 +71,52 @@ func (ProxyMappingUpdateType) EnumDescriptor() ([]byte, []int) { return file_proxy_service_proto_rawDescGZIP(), []int{0} } +type PathRewriteMode int32 + +const ( + PathRewriteMode_PATH_REWRITE_DEFAULT PathRewriteMode = 0 + PathRewriteMode_PATH_REWRITE_PRESERVE PathRewriteMode = 1 +) + +// Enum value maps for PathRewriteMode. +var ( + PathRewriteMode_name = map[int32]string{ + 0: "PATH_REWRITE_DEFAULT", + 1: "PATH_REWRITE_PRESERVE", + } + PathRewriteMode_value = map[string]int32{ + "PATH_REWRITE_DEFAULT": 0, + "PATH_REWRITE_PRESERVE": 1, + } +) + +func (x PathRewriteMode) Enum() *PathRewriteMode { + p := new(PathRewriteMode) + *p = x + return p +} + +func (x PathRewriteMode) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (PathRewriteMode) Descriptor() protoreflect.EnumDescriptor { + return file_proxy_service_proto_enumTypes[1].Descriptor() +} + +func (PathRewriteMode) Type() protoreflect.EnumType { + return &file_proxy_service_proto_enumTypes[1] +} + +func (x PathRewriteMode) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use PathRewriteMode.Descriptor instead. +func (PathRewriteMode) EnumDescriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{1} +} + type ProxyStatus int32 const ( @@ -112,11 +159,11 @@ func (x ProxyStatus) String() string { } func (ProxyStatus) Descriptor() protoreflect.EnumDescriptor { - return file_proxy_service_proto_enumTypes[1].Descriptor() + return file_proxy_service_proto_enumTypes[2].Descriptor() } func (ProxyStatus) Type() protoreflect.EnumType { - return &file_proxy_service_proto_enumTypes[1] + return &file_proxy_service_proto_enumTypes[2] } func (x ProxyStatus) Number() protoreflect.EnumNumber { @@ -125,7 +172,75 @@ func (x ProxyStatus) Number() protoreflect.EnumNumber { // Deprecated: Use ProxyStatus.Descriptor instead. func (ProxyStatus) EnumDescriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{1} + return file_proxy_service_proto_rawDescGZIP(), []int{2} +} + +// ProxyCapabilities describes what a proxy can handle. +type ProxyCapabilities struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` + // Whether the proxy requires a subdomain label in front of its cluster domain. + // When true, accounts cannot use the cluster domain bare. + RequireSubdomain *bool `protobuf:"varint,2,opt,name=require_subdomain,json=requireSubdomain,proto3,oneof" json:"require_subdomain,omitempty"` + // Whether the proxy has CrowdSec configured and can enforce IP reputation checks. + SupportsCrowdsec *bool `protobuf:"varint,3,opt,name=supports_crowdsec,json=supportsCrowdsec,proto3,oneof" json:"supports_crowdsec,omitempty"` +} + +func (x *ProxyCapabilities) Reset() { + *x = ProxyCapabilities{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProxyCapabilities) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProxyCapabilities) ProtoMessage() {} + +func (x *ProxyCapabilities) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProxyCapabilities.ProtoReflect.Descriptor instead. +func (*ProxyCapabilities) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{0} +} + +func (x *ProxyCapabilities) GetSupportsCustomPorts() bool { + if x != nil && x.SupportsCustomPorts != nil { + return *x.SupportsCustomPorts + } + return false +} + +func (x *ProxyCapabilities) GetRequireSubdomain() bool { + if x != nil && x.RequireSubdomain != nil { + return *x.RequireSubdomain + } + return false +} + +func (x *ProxyCapabilities) GetSupportsCrowdsec() bool { + if x != nil && x.SupportsCrowdsec != nil { + return *x.SupportsCrowdsec + } + return false } // GetMappingUpdateRequest is sent to initialise a mapping stream. @@ -134,16 +249,17 @@ type GetMappingUpdateRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` - Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` - Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` } func (x *GetMappingUpdateRequest) Reset() { *x = GetMappingUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -156,7 +272,7 @@ func (x *GetMappingUpdateRequest) String() string { func (*GetMappingUpdateRequest) ProtoMessage() {} func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -169,7 +285,7 @@ func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateRequest.ProtoReflect.Descriptor instead. func (*GetMappingUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{0} + return file_proxy_service_proto_rawDescGZIP(), []int{1} } func (x *GetMappingUpdateRequest) GetProxyId() string { @@ -200,6 +316,13 @@ func (x *GetMappingUpdateRequest) GetAddress() string { return "" } +func (x *GetMappingUpdateRequest) GetCapabilities() *ProxyCapabilities { + if x != nil { + return x.Capabilities + } + return nil +} + // GetMappingUpdateResponse contains zero or more ProxyMappings. // No mappings may be sent to test the liveness of the Proxy. // Mappings that are sent should be interpreted by the Proxy appropriately. @@ -217,7 +340,7 @@ type GetMappingUpdateResponse struct { func (x *GetMappingUpdateResponse) Reset() { *x = GetMappingUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -230,7 +353,7 @@ func (x *GetMappingUpdateResponse) String() string { func (*GetMappingUpdateResponse) ProtoMessage() {} func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -243,7 +366,7 @@ func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateResponse.ProtoReflect.Descriptor instead. func (*GetMappingUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{1} + return file_proxy_service_proto_rawDescGZIP(), []int{2} } func (x *GetMappingUpdateResponse) GetMapping() []*ProxyMapping { @@ -260,19 +383,109 @@ func (x *GetMappingUpdateResponse) GetInitialSyncComplete() bool { return false } +type PathTargetOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SkipTlsVerify bool `protobuf:"varint,1,opt,name=skip_tls_verify,json=skipTlsVerify,proto3" json:"skip_tls_verify,omitempty"` + RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` + PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` + CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Send PROXY protocol v2 header to this backend. + ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` + // Idle timeout before a UDP session is reaped. + SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` +} + +func (x *PathTargetOptions) Reset() { + *x = PathTargetOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PathTargetOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PathTargetOptions) ProtoMessage() {} + +func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PathTargetOptions.ProtoReflect.Descriptor instead. +func (*PathTargetOptions) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{3} +} + +func (x *PathTargetOptions) GetSkipTlsVerify() bool { + if x != nil { + return x.SkipTlsVerify + } + return false +} + +func (x *PathTargetOptions) GetRequestTimeout() *durationpb.Duration { + if x != nil { + return x.RequestTimeout + } + return nil +} + +func (x *PathTargetOptions) GetPathRewrite() PathRewriteMode { + if x != nil { + return x.PathRewrite + } + return PathRewriteMode_PATH_REWRITE_DEFAULT +} + +func (x *PathTargetOptions) GetCustomHeaders() map[string]string { + if x != nil { + return x.CustomHeaders + } + return nil +} + +func (x *PathTargetOptions) GetProxyProtocol() bool { + if x != nil { + return x.ProxyProtocol + } + return false +} + +func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { + if x != nil { + return x.SessionIdleTimeout + } + return nil +} + type PathMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` - Target string `protobuf:"bytes,2,opt,name=target,proto3" json:"target,omitempty"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Target string `protobuf:"bytes,2,opt,name=target,proto3" json:"target,omitempty"` + Options *PathTargetOptions `protobuf:"bytes,3,opt,name=options,proto3" json:"options,omitempty"` } func (x *PathMapping) Reset() { *x = PathMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -285,7 +498,7 @@ func (x *PathMapping) String() string { func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -298,7 +511,7 @@ func (x *PathMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use PathMapping.ProtoReflect.Descriptor instead. func (*PathMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{2} + return file_proxy_service_proto_rawDescGZIP(), []int{4} } func (x *PathMapping) GetPath() string { @@ -315,22 +528,87 @@ func (x *PathMapping) GetTarget() string { return "" } +func (x *PathMapping) GetOptions() *PathTargetOptions { + if x != nil { + return x.Options + } + return nil +} + +type HeaderAuth struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Header name to check, e.g. "Authorization", "X-API-Key". + Header string `protobuf:"bytes,1,opt,name=header,proto3" json:"header,omitempty"` + // argon2id hash of the expected full header value. + HashedValue string `protobuf:"bytes,2,opt,name=hashed_value,json=hashedValue,proto3" json:"hashed_value,omitempty"` +} + +func (x *HeaderAuth) Reset() { + *x = HeaderAuth{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HeaderAuth) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeaderAuth) ProtoMessage() {} + +func (x *HeaderAuth) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeaderAuth.ProtoReflect.Descriptor instead. +func (*HeaderAuth) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{5} +} + +func (x *HeaderAuth) GetHeader() string { + if x != nil { + return x.Header + } + return "" +} + +func (x *HeaderAuth) GetHashedValue() string { + if x != nil { + return x.HashedValue + } + return "" +} + type Authentication struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - SessionKey string `protobuf:"bytes,1,opt,name=session_key,json=sessionKey,proto3" json:"session_key,omitempty"` - MaxSessionAgeSeconds int64 `protobuf:"varint,2,opt,name=max_session_age_seconds,json=maxSessionAgeSeconds,proto3" json:"max_session_age_seconds,omitempty"` - Password bool `protobuf:"varint,3,opt,name=password,proto3" json:"password,omitempty"` - Pin bool `protobuf:"varint,4,opt,name=pin,proto3" json:"pin,omitempty"` - Oidc bool `protobuf:"varint,5,opt,name=oidc,proto3" json:"oidc,omitempty"` + SessionKey string `protobuf:"bytes,1,opt,name=session_key,json=sessionKey,proto3" json:"session_key,omitempty"` + MaxSessionAgeSeconds int64 `protobuf:"varint,2,opt,name=max_session_age_seconds,json=maxSessionAgeSeconds,proto3" json:"max_session_age_seconds,omitempty"` + Password bool `protobuf:"varint,3,opt,name=password,proto3" json:"password,omitempty"` + Pin bool `protobuf:"varint,4,opt,name=pin,proto3" json:"pin,omitempty"` + Oidc bool `protobuf:"varint,5,opt,name=oidc,proto3" json:"oidc,omitempty"` + HeaderAuths []*HeaderAuth `protobuf:"bytes,6,rep,name=header_auths,json=headerAuths,proto3" json:"header_auths,omitempty"` } func (x *Authentication) Reset() { *x = Authentication{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -343,7 +621,7 @@ func (x *Authentication) String() string { func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -356,7 +634,7 @@ func (x *Authentication) ProtoReflect() protoreflect.Message { // Deprecated: Use Authentication.ProtoReflect.Descriptor instead. func (*Authentication) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{3} + return file_proxy_service_proto_rawDescGZIP(), []int{6} } func (x *Authentication) GetSessionKey() string { @@ -394,6 +672,93 @@ func (x *Authentication) GetOidc() bool { return false } +func (x *Authentication) GetHeaderAuths() []*HeaderAuth { + if x != nil { + return x.HeaderAuths + } + return nil +} + +type AccessRestrictions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AllowedCidrs []string `protobuf:"bytes,1,rep,name=allowed_cidrs,json=allowedCidrs,proto3" json:"allowed_cidrs,omitempty"` + BlockedCidrs []string `protobuf:"bytes,2,rep,name=blocked_cidrs,json=blockedCidrs,proto3" json:"blocked_cidrs,omitempty"` + AllowedCountries []string `protobuf:"bytes,3,rep,name=allowed_countries,json=allowedCountries,proto3" json:"allowed_countries,omitempty"` + BlockedCountries []string `protobuf:"bytes,4,rep,name=blocked_countries,json=blockedCountries,proto3" json:"blocked_countries,omitempty"` + // CrowdSec IP reputation mode: "", "off", "enforce", or "observe". + CrowdsecMode string `protobuf:"bytes,5,opt,name=crowdsec_mode,json=crowdsecMode,proto3" json:"crowdsec_mode,omitempty"` +} + +func (x *AccessRestrictions) Reset() { + *x = AccessRestrictions{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AccessRestrictions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AccessRestrictions) ProtoMessage() {} + +func (x *AccessRestrictions) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AccessRestrictions.ProtoReflect.Descriptor instead. +func (*AccessRestrictions) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{7} +} + +func (x *AccessRestrictions) GetAllowedCidrs() []string { + if x != nil { + return x.AllowedCidrs + } + return nil +} + +func (x *AccessRestrictions) GetBlockedCidrs() []string { + if x != nil { + return x.BlockedCidrs + } + return nil +} + +func (x *AccessRestrictions) GetAllowedCountries() []string { + if x != nil { + return x.AllowedCountries + } + return nil +} + +func (x *AccessRestrictions) GetBlockedCountries() []string { + if x != nil { + return x.BlockedCountries + } + return nil +} + +func (x *AccessRestrictions) GetCrowdsecMode() string { + if x != nil { + return x.CrowdsecMode + } + return "" +} + type ProxyMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -412,12 +777,17 @@ type ProxyMapping struct { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. RewriteRedirects bool `protobuf:"varint,9,opt,name=rewrite_redirects,json=rewriteRedirects,proto3" json:"rewrite_redirects,omitempty"` + // Service mode: "http", "tcp", "udp", or "tls". + Mode string `protobuf:"bytes,10,opt,name=mode,proto3" json:"mode,omitempty"` + // For L4/TLS: the port the proxy listens on. + ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` + AccessRestrictions *AccessRestrictions `protobuf:"bytes,12,opt,name=access_restrictions,json=accessRestrictions,proto3" json:"access_restrictions,omitempty"` } func (x *ProxyMapping) Reset() { *x = ProxyMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -430,7 +800,7 @@ func (x *ProxyMapping) String() string { func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -443,7 +813,7 @@ func (x *ProxyMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyMapping.ProtoReflect.Descriptor instead. func (*ProxyMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{4} + return file_proxy_service_proto_rawDescGZIP(), []int{8} } func (x *ProxyMapping) GetType() ProxyMappingUpdateType { @@ -509,6 +879,27 @@ func (x *ProxyMapping) GetRewriteRedirects() bool { return false } +func (x *ProxyMapping) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + +func (x *ProxyMapping) GetListenPort() int32 { + if x != nil { + return x.ListenPort + } + return 0 +} + +func (x *ProxyMapping) GetAccessRestrictions() *AccessRestrictions { + if x != nil { + return x.AccessRestrictions + } + return nil +} + // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. type SendAccessLogRequest struct { state protoimpl.MessageState @@ -521,7 +912,7 @@ type SendAccessLogRequest struct { func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -534,7 +925,7 @@ func (x *SendAccessLogRequest) String() string { func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -547,7 +938,7 @@ func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogRequest.ProtoReflect.Descriptor instead. func (*SendAccessLogRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{5} + return file_proxy_service_proto_rawDescGZIP(), []int{9} } func (x *SendAccessLogRequest) GetLog() *AccessLog { @@ -567,7 +958,7 @@ type SendAccessLogResponse struct { func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -580,7 +971,7 @@ func (x *SendAccessLogResponse) String() string { func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -593,7 +984,7 @@ func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogResponse.ProtoReflect.Descriptor instead. func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{6} + return file_proxy_service_proto_rawDescGZIP(), []int{10} } type AccessLog struct { @@ -614,12 +1005,17 @@ type AccessLog struct { AuthMechanism string `protobuf:"bytes,11,opt,name=auth_mechanism,json=authMechanism,proto3" json:"auth_mechanism,omitempty"` UserId string `protobuf:"bytes,12,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` AuthSuccess bool `protobuf:"varint,13,opt,name=auth_success,json=authSuccess,proto3" json:"auth_success,omitempty"` + BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` + BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` + Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` + // Extra key-value metadata for the access log entry (e.g. crowdsec_verdict, scenario). + Metadata map[string]string `protobuf:"bytes,17,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *AccessLog) Reset() { *x = AccessLog{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -632,7 +1028,7 @@ func (x *AccessLog) String() string { func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -645,7 +1041,7 @@ func (x *AccessLog) ProtoReflect() protoreflect.Message { // Deprecated: Use AccessLog.ProtoReflect.Descriptor instead. func (*AccessLog) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{7} + return file_proxy_service_proto_rawDescGZIP(), []int{11} } func (x *AccessLog) GetTimestamp() *timestamppb.Timestamp { @@ -739,6 +1135,34 @@ func (x *AccessLog) GetAuthSuccess() bool { return false } +func (x *AccessLog) GetBytesUpload() int64 { + if x != nil { + return x.BytesUpload + } + return 0 +} + +func (x *AccessLog) GetBytesDownload() int64 { + if x != nil { + return x.BytesDownload + } + return 0 +} + +func (x *AccessLog) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +func (x *AccessLog) GetMetadata() map[string]string { + if x != nil { + return x.Metadata + } + return nil +} + type AuthenticateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -750,13 +1174,14 @@ type AuthenticateRequest struct { // // *AuthenticateRequest_Password // *AuthenticateRequest_Pin + // *AuthenticateRequest_HeaderAuth Request isAuthenticateRequest_Request `protobuf_oneof:"request"` } func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -769,7 +1194,7 @@ func (x *AuthenticateRequest) String() string { func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -782,7 +1207,7 @@ func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{8} + return file_proxy_service_proto_rawDescGZIP(), []int{12} } func (x *AuthenticateRequest) GetId() string { @@ -820,6 +1245,13 @@ func (x *AuthenticateRequest) GetPin() *PinRequest { return nil } +func (x *AuthenticateRequest) GetHeaderAuth() *HeaderAuthRequest { + if x, ok := x.GetRequest().(*AuthenticateRequest_HeaderAuth); ok { + return x.HeaderAuth + } + return nil +} + type isAuthenticateRequest_Request interface { isAuthenticateRequest_Request() } @@ -832,10 +1264,71 @@ type AuthenticateRequest_Pin struct { Pin *PinRequest `protobuf:"bytes,4,opt,name=pin,proto3,oneof"` } +type AuthenticateRequest_HeaderAuth struct { + HeaderAuth *HeaderAuthRequest `protobuf:"bytes,5,opt,name=header_auth,json=headerAuth,proto3,oneof"` +} + func (*AuthenticateRequest_Password) isAuthenticateRequest_Request() {} func (*AuthenticateRequest_Pin) isAuthenticateRequest_Request() {} +func (*AuthenticateRequest_HeaderAuth) isAuthenticateRequest_Request() {} + +type HeaderAuthRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + HeaderValue string `protobuf:"bytes,1,opt,name=header_value,json=headerValue,proto3" json:"header_value,omitempty"` + HeaderName string `protobuf:"bytes,2,opt,name=header_name,json=headerName,proto3" json:"header_name,omitempty"` +} + +func (x *HeaderAuthRequest) Reset() { + *x = HeaderAuthRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HeaderAuthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeaderAuthRequest) ProtoMessage() {} + +func (x *HeaderAuthRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[13] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeaderAuthRequest.ProtoReflect.Descriptor instead. +func (*HeaderAuthRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{13} +} + +func (x *HeaderAuthRequest) GetHeaderValue() string { + if x != nil { + return x.HeaderValue + } + return "" +} + +func (x *HeaderAuthRequest) GetHeaderName() string { + if x != nil { + return x.HeaderName + } + return "" +} + type PasswordRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -847,7 +1340,7 @@ type PasswordRequest struct { func (x *PasswordRequest) Reset() { *x = PasswordRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -860,7 +1353,7 @@ func (x *PasswordRequest) String() string { func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -873,7 +1366,7 @@ func (x *PasswordRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PasswordRequest.ProtoReflect.Descriptor instead. func (*PasswordRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{9} + return file_proxy_service_proto_rawDescGZIP(), []int{14} } func (x *PasswordRequest) GetPassword() string { @@ -894,7 +1387,7 @@ type PinRequest struct { func (x *PinRequest) Reset() { *x = PinRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -907,7 +1400,7 @@ func (x *PinRequest) String() string { func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -920,7 +1413,7 @@ func (x *PinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PinRequest.ProtoReflect.Descriptor instead. func (*PinRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{10} + return file_proxy_service_proto_rawDescGZIP(), []int{15} } func (x *PinRequest) GetPin() string { @@ -942,7 +1435,7 @@ type AuthenticateResponse struct { func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -955,7 +1448,7 @@ func (x *AuthenticateResponse) String() string { func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[16] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -968,7 +1461,7 @@ func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. func (*AuthenticateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{11} + return file_proxy_service_proto_rawDescGZIP(), []int{16} } func (x *AuthenticateResponse) GetSuccess() bool { @@ -1001,7 +1494,7 @@ type SendStatusUpdateRequest struct { func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1014,7 +1507,7 @@ func (x *SendStatusUpdateRequest) String() string { func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[17] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1027,7 +1520,7 @@ func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateRequest.ProtoReflect.Descriptor instead. func (*SendStatusUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{12} + return file_proxy_service_proto_rawDescGZIP(), []int{17} } func (x *SendStatusUpdateRequest) GetServiceId() string { @@ -1075,7 +1568,7 @@ type SendStatusUpdateResponse struct { func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1088,7 +1581,7 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1101,7 +1594,7 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{13} + return file_proxy_service_proto_rawDescGZIP(), []int{18} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection @@ -1121,7 +1614,7 @@ type CreateProxyPeerRequest struct { func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1134,7 +1627,7 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1147,7 +1640,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{14} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1198,7 +1691,7 @@ type CreateProxyPeerResponse struct { func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1211,7 +1704,7 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1224,7 +1717,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{15} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1254,7 +1747,7 @@ type GetOIDCURLRequest struct { func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1267,7 +1760,7 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1280,7 +1773,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{16} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *GetOIDCURLRequest) GetId() string { @@ -1315,7 +1808,7 @@ type GetOIDCURLResponse struct { func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1328,7 +1821,7 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1341,7 +1834,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{17} + return file_proxy_service_proto_rawDescGZIP(), []int{22} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1363,7 +1856,7 @@ type ValidateSessionRequest struct { func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1376,7 +1869,7 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1389,7 +1882,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{23} } func (x *ValidateSessionRequest) GetDomain() string { @@ -1420,7 +1913,7 @@ type ValidateSessionResponse struct { func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1433,7 +1926,7 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1446,7 +1939,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{24} } func (x *ValidateSessionResponse) GetValid() bool { @@ -1482,70 +1975,155 @@ var File_proxy_service_proto protoreflect.FileDescriptor var file_proxy_service_proto_rawDesc = []byte{ 0x0a, 0x13, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x74, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0xa3, 0x01, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, - 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, - 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, - 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, - 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0x39, 0x0a, - 0x0b, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, - 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, - 0x12, 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x22, 0xaa, 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, - 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, - 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, - 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, - 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, - 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, - 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, - 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, - 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0xe0, 0x02, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, - 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, - 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, - 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, - 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, - 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, - 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, - 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, - 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, - 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, - 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, - 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, + 0x74, 0x6f, 0x22, 0xf6, 0x01, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, + 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, + 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, + 0x72, 0x74, 0x73, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x6f, 0x72, 0x74, 0x73, 0x88, 0x01, + 0x01, 0x12, 0x30, 0x0a, 0x11, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x73, 0x75, 0x62, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x48, 0x01, 0x52, 0x10, + 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x53, 0x75, 0x62, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x11, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, + 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x48, 0x02, + 0x52, 0x10, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x43, 0x72, 0x6f, 0x77, 0x64, 0x73, + 0x65, 0x63, 0x88, 0x01, 0x01, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x42, + 0x14, 0x0a, 0x12, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x73, 0x75, 0x62, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x5f, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x22, 0xe6, 0x01, 0x0a, 0x17, + 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, + 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, + 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, + 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, + 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, + 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, + 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, + 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, + 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, + 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, + 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, + 0x6f, 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, + 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, + 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, + 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, + 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, + 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x47, + 0x0a, 0x0a, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, + 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x68, 0x65, + 0x61, 0x64, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x61, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x61, 0x73, 0x68, + 0x65, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xe5, 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, + 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, + 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, + 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, + 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, + 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, + 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, + 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, + 0x6f, 0x69, 0x64, 0x63, 0x12, 0x39, 0x0a, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x61, + 0x75, 0x74, 0x68, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, + 0x74, 0x68, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x73, 0x22, + 0xdd, 0x01, 0x0a, 0x12, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, + 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x62, + 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x69, 0x64, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0c, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x43, 0x69, 0x64, 0x72, 0x73, + 0x12, 0x2b, 0x0a, 0x11, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x72, 0x69, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x61, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x2b, 0x0a, + 0x11, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, + 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, + 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x72, + 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0c, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x4d, 0x6f, 0x64, 0x65, 0x22, + 0xe6, 0x03, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x12, 0x36, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, + 0x2b, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x61, 0x75, 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, + 0x75, 0x74, 0x68, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, + 0x61, 0x73, 0x73, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, + 0x5f, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x10, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, + 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, + 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x4f, 0x0a, 0x13, 0x61, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x5f, 0x72, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x12, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, + 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0xa0, 0x03, 0x0a, 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, + 0x73, 0x65, 0x22, 0x84, 0x05, 0x0a, 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, @@ -1571,148 +2149,176 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, - 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, - 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, - 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, - 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, - 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, - 0x70, 0x69, 0x6e, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x2d, - 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, - 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, - 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, - 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, - 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, + 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, + 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, + 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3f, 0x0a, 0x08, 0x6d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x11, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, + 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xf8, 0x01, 0x0a, 0x13, 0x41, 0x75, + 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, + 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, + 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, + 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, + 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x40, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x68, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x57, 0x0a, 0x11, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, + 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, + 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x2d, 0x0a, + 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, 0x0a, + 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, 0x14, + 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x23, + 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, + 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, 0x0a, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x2d, + 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, 0x73, + 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, 0x74, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, 0x0a, + 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, + 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, 0x6e, + 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, - 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, - 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, - 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, - 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, - 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, - 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, - 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, - 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, - 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, - 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, - 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, - 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, - 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, - 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, - 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, - 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, - 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, - 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, - 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, - 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, - 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, - 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, - 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, - 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, - 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, - 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, - 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, - 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, - 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, - 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, - 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, - 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, - 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0xc8, 0x01, - 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, - 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, - 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, - 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, - 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, - 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, - 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, - 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, - 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, - 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, - 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, - 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, - 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, - 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, - 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, - 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, - 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, - 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, - 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, - 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, - 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, - 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x14, + 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, + 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, 0x62, + 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, + 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, + 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, + 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, + 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, + 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, + 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, + 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4f, + 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, + 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, + 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, + 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, + 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, + 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, + 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, + 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, + 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, + 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, + 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, + 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, + 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, + 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, + 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, + 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, + 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, + 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, + 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, + 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, + 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, + 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, + 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, + 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, + 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, + 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, + 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, + 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, + 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, + 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, + 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, + 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, + 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, + 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, + 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -1727,63 +2333,82 @@ func file_proxy_service_proto_rawDescGZIP() []byte { return file_proxy_service_proto_rawDescData } -var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 20) +var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 27) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType - (ProxyStatus)(0), // 1: management.ProxyStatus - (*GetMappingUpdateRequest)(nil), // 2: management.GetMappingUpdateRequest - (*GetMappingUpdateResponse)(nil), // 3: management.GetMappingUpdateResponse - (*PathMapping)(nil), // 4: management.PathMapping - (*Authentication)(nil), // 5: management.Authentication - (*ProxyMapping)(nil), // 6: management.ProxyMapping - (*SendAccessLogRequest)(nil), // 7: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 8: management.SendAccessLogResponse - (*AccessLog)(nil), // 9: management.AccessLog - (*AuthenticateRequest)(nil), // 10: management.AuthenticateRequest - (*PasswordRequest)(nil), // 11: management.PasswordRequest - (*PinRequest)(nil), // 12: management.PinRequest - (*AuthenticateResponse)(nil), // 13: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 14: management.SendStatusUpdateRequest - (*SendStatusUpdateResponse)(nil), // 15: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 16: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 17: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 18: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 19: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 20: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 21: management.ValidateSessionResponse - (*timestamppb.Timestamp)(nil), // 22: google.protobuf.Timestamp + (PathRewriteMode)(0), // 1: management.PathRewriteMode + (ProxyStatus)(0), // 2: management.ProxyStatus + (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities + (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest + (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse + (*PathTargetOptions)(nil), // 6: management.PathTargetOptions + (*PathMapping)(nil), // 7: management.PathMapping + (*HeaderAuth)(nil), // 8: management.HeaderAuth + (*Authentication)(nil), // 9: management.Authentication + (*AccessRestrictions)(nil), // 10: management.AccessRestrictions + (*ProxyMapping)(nil), // 11: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 12: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 13: management.SendAccessLogResponse + (*AccessLog)(nil), // 14: management.AccessLog + (*AuthenticateRequest)(nil), // 15: management.AuthenticateRequest + (*HeaderAuthRequest)(nil), // 16: management.HeaderAuthRequest + (*PasswordRequest)(nil), // 17: management.PasswordRequest + (*PinRequest)(nil), // 18: management.PinRequest + (*AuthenticateResponse)(nil), // 19: management.AuthenticateResponse + (*SendStatusUpdateRequest)(nil), // 20: management.SendStatusUpdateRequest + (*SendStatusUpdateResponse)(nil), // 21: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 22: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 23: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 24: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 25: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 26: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 27: management.ValidateSessionResponse + nil, // 28: management.PathTargetOptions.CustomHeadersEntry + nil, // 29: management.AccessLog.MetadataEntry + (*timestamppb.Timestamp)(nil), // 30: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 31: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 22, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp - 6, // 1: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 0, // 2: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType - 4, // 3: management.ProxyMapping.path:type_name -> management.PathMapping - 5, // 4: management.ProxyMapping.auth:type_name -> management.Authentication - 9, // 5: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 22, // 6: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 11, // 7: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 12, // 8: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 1, // 9: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 2, // 10: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 7, // 11: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 10, // 12: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 14, // 13: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 16, // 14: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 18, // 15: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 20, // 16: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 3, // 17: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 8, // 18: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 13, // 19: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 15, // 20: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 17, // 21: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 19, // 22: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 21, // 23: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 17, // [17:24] is the sub-list for method output_type - 10, // [10:17] is the sub-list for method input_type - 10, // [10:10] is the sub-list for extension type_name - 10, // [10:10] is the sub-list for extension extendee - 0, // [0:10] is the sub-list for field type_name + 30, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities + 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping + 31, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode + 28, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 31, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions + 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth + 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType + 7, // 10: management.ProxyMapping.path:type_name -> management.PathMapping + 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication + 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions + 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog + 30, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 29, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry + 17, // 16: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 18, // 17: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 16, // 18: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest + 2, // 19: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 4, // 20: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 12, // 21: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 15, // 22: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 20, // 23: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 22, // 24: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 24, // 25: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 26, // 26: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 27: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 13, // 28: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 19, // 29: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 21, // 30: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 23, // 31: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 25, // 32: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 27, // 33: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 27, // [27:34] is the sub-list for method output_type + 20, // [20:27] is the sub-list for method input_type + 20, // [20:20] is the sub-list for extension type_name + 20, // [20:20] is the sub-list for extension extendee + 0, // [0:20] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -1793,7 +2418,7 @@ func file_proxy_service_proto_init() { } if !protoimpl.UnsafeEnabled { file_proxy_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateRequest); i { + switch v := v.(*ProxyCapabilities); i { case 0: return &v.state case 1: @@ -1805,7 +2430,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateResponse); i { + switch v := v.(*GetMappingUpdateRequest); i { case 0: return &v.state case 1: @@ -1817,7 +2442,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathMapping); i { + switch v := v.(*GetMappingUpdateResponse); i { case 0: return &v.state case 1: @@ -1829,7 +2454,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Authentication); i { + switch v := v.(*PathTargetOptions); i { case 0: return &v.state case 1: @@ -1841,7 +2466,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyMapping); i { + switch v := v.(*PathMapping); i { case 0: return &v.state case 1: @@ -1853,7 +2478,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogRequest); i { + switch v := v.(*HeaderAuth); i { case 0: return &v.state case 1: @@ -1865,7 +2490,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogResponse); i { + switch v := v.(*Authentication); i { case 0: return &v.state case 1: @@ -1877,7 +2502,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessLog); i { + switch v := v.(*AccessRestrictions); i { case 0: return &v.state case 1: @@ -1889,7 +2514,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateRequest); i { + switch v := v.(*ProxyMapping); i { case 0: return &v.state case 1: @@ -1901,7 +2526,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PasswordRequest); i { + switch v := v.(*SendAccessLogRequest); i { case 0: return &v.state case 1: @@ -1913,7 +2538,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PinRequest); i { + switch v := v.(*SendAccessLogResponse); i { case 0: return &v.state case 1: @@ -1925,7 +2550,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateResponse); i { + switch v := v.(*AccessLog); i { case 0: return &v.state case 1: @@ -1937,7 +2562,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateRequest); i { + switch v := v.(*AuthenticateRequest); i { case 0: return &v.state case 1: @@ -1949,7 +2574,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { + switch v := v.(*HeaderAuthRequest); i { case 0: return &v.state case 1: @@ -1961,7 +2586,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { + switch v := v.(*PasswordRequest); i { case 0: return &v.state case 1: @@ -1973,7 +2598,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { + switch v := v.(*PinRequest); i { case 0: return &v.state case 1: @@ -1985,7 +2610,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { + switch v := v.(*AuthenticateResponse); i { case 0: return &v.state case 1: @@ -1997,7 +2622,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { + switch v := v.(*SendStatusUpdateRequest); i { case 0: return &v.state case 1: @@ -2009,7 +2634,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { + switch v := v.(*SendStatusUpdateResponse); i { case 0: return &v.state case 1: @@ -2021,6 +2646,66 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateProxyPeerRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateProxyPeerResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetOIDCURLRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetOIDCURLResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValidateSessionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ValidateSessionResponse); i { case 0: return &v.state @@ -2033,19 +2718,21 @@ func file_proxy_service_proto_init() { } } } - file_proxy_service_proto_msgTypes[8].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[12].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), + (*AuthenticateRequest_HeaderAuth)(nil), } - file_proxy_service_proto_msgTypes[12].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[15].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[20].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, - NumEnums: 2, - NumMessages: 20, + NumEnums: 3, + NumMessages: 27, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index b4e62a52a..e359f0cbd 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -4,6 +4,7 @@ package management; option go_package = "/proto"; +import "google/protobuf/duration.proto"; import "google/protobuf/timestamp.proto"; // ProxyService - Management is the SERVER, Proxy is the CLIENT @@ -26,12 +27,24 @@ service ProxyService { rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse); } +// ProxyCapabilities describes what a proxy can handle. +message ProxyCapabilities { + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + optional bool supports_custom_ports = 1; + // Whether the proxy requires a subdomain label in front of its cluster domain. + // When true, accounts cannot use the cluster domain bare. + optional bool require_subdomain = 2; + // Whether the proxy has CrowdSec configured and can enforce IP reputation checks. + optional bool supports_crowdsec = 3; +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. message GetMappingUpdateRequest { string proxy_id = 1; string version = 2; google.protobuf.Timestamp started_at = 3; string address = 4; + ProxyCapabilities capabilities = 5; } // GetMappingUpdateResponse contains zero or more ProxyMappings. @@ -50,9 +63,33 @@ enum ProxyMappingUpdateType { UPDATE_TYPE_REMOVED = 2; } +enum PathRewriteMode { + PATH_REWRITE_DEFAULT = 0; + PATH_REWRITE_PRESERVE = 1; +} + +message PathTargetOptions { + bool skip_tls_verify = 1; + google.protobuf.Duration request_timeout = 2; + PathRewriteMode path_rewrite = 3; + map custom_headers = 4; + // Send PROXY protocol v2 header to this backend. + bool proxy_protocol = 5; + // Idle timeout before a UDP session is reaped. + google.protobuf.Duration session_idle_timeout = 6; +} + message PathMapping { string path = 1; string target = 2; + PathTargetOptions options = 3; +} + +message HeaderAuth { + // Header name to check, e.g. "Authorization", "X-API-Key". + string header = 1; + // argon2id hash of the expected full header value. + string hashed_value = 2; } message Authentication { @@ -61,6 +98,16 @@ message Authentication { bool password = 3; bool pin = 4; bool oidc = 5; + repeated HeaderAuth header_auths = 6; +} + +message AccessRestrictions { + repeated string allowed_cidrs = 1; + repeated string blocked_cidrs = 2; + repeated string allowed_countries = 3; + repeated string blocked_countries = 4; + // CrowdSec IP reputation mode: "", "off", "enforce", or "observe". + string crowdsec_mode = 5; } message ProxyMapping { @@ -77,6 +124,11 @@ message ProxyMapping { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. bool rewrite_redirects = 9; + // Service mode: "http", "tcp", "udp", or "tls". + string mode = 10; + // For L4/TLS: the port the proxy listens on. + int32 listen_port = 11; + AccessRestrictions access_restrictions = 12; } // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. @@ -101,6 +153,11 @@ message AccessLog { string auth_mechanism = 11; string user_id = 12; bool auth_success = 13; + int64 bytes_upload = 14; + int64 bytes_download = 15; + string protocol = 16; + // Extra key-value metadata for the access log entry (e.g. crowdsec_verdict, scenario). + map metadata = 17; } message AuthenticateRequest { @@ -109,9 +166,15 @@ message AuthenticateRequest { oneof request { PasswordRequest password = 3; PinRequest pin = 4; + HeaderAuthRequest header_auth = 5; } } +message HeaderAuthRequest { + string header_value = 1; + string header_name = 2; +} + message PasswordRequest { string password = 1; } diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index ed1b63435..b10b05617 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -333,7 +333,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { dialers := c.getDialers() rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) - conn, err := rd.Dial() + conn, err := rd.Dial(ctx) if err != nil { return nil, err } diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 78462837d..2d7b00a80 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -89,12 +89,12 @@ func prepareURL(address string) (string, error) { finalHost, finalPort, err := net.SplitHostPort(host) if err != nil { if strings.Contains(err.Error(), "missing port") { - return host + ":" + defaultPort, nil + return net.JoinHostPort(strings.Trim(host, "[]"), defaultPort), nil } // return any other split error as is return "", err } - return finalHost + ":" + finalPort, nil + return net.JoinHostPort(finalHost, finalPort), nil } diff --git a/shared/relay/client/dialer/race_dialer.go b/shared/relay/client/dialer/race_dialer.go index 0550fc63e..34359d17e 100644 --- a/shared/relay/client/dialer/race_dialer.go +++ b/shared/relay/client/dialer/race_dialer.go @@ -40,10 +40,10 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri } } -func (r *RaceDial) Dial() (net.Conn, error) { +func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) { connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) - abortCtx, abort := context.WithCancel(context.Background()) + abortCtx, abort := context.WithCancel(ctx) defer abort() for _, dfn := range r.dialerFns { diff --git a/shared/relay/client/dialer/race_dialer_test.go b/shared/relay/client/dialer/race_dialer_test.go index d216ec5e7..aa18df578 100644 --- a/shared/relay/client/dialer/race_dialer_test.go +++ b/shared/relay/client/dialer/race_dialer_test.go @@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) { serverURL := "test.server.com" rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error with empty dialers, got nil") } @@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) { } rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } diff --git a/shared/relay/client/early_msg_buffer.go b/shared/relay/client/early_msg_buffer.go index 3ead94de1..52ff4d42e 100644 --- a/shared/relay/client/early_msg_buffer.go +++ b/shared/relay/client/early_msg_buffer.go @@ -65,8 +65,8 @@ func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool { } entry := earlyMsg{ - peerID: peerID, - msg: msg, + peerID: peerID, + msg: msg, createdAt: time.Now(), } elem := b.order.PushBack(entry) diff --git a/tools/idp-migrate/DEVELOPMENT.md b/tools/idp-migrate/DEVELOPMENT.md new file mode 100644 index 000000000..5697ead40 --- /dev/null +++ b/tools/idp-migrate/DEVELOPMENT.md @@ -0,0 +1,209 @@ +# IdP Migration Tool — Developer Guide + +## Overview + +This tool migrates NetBird deployments from an external IdP (Auth0, Zitadel, Okta, etc.) to the embedded Dex IdP introduced in v0.62.0. It does two things: + +1. **DB migration** — Re-encodes every user ID from `{original_id}` to Dex's protobuf-encoded format `base64(proto{original_id, connector_id})`. +2. **Config generation** — Transforms `management.json`: removes `IdpManagerConfig`, `PKCEAuthorizationFlow`, and `DeviceAuthorizationFlow`; strips `HttpConfig` to only `CertFile`/`CertKey`; adds `EmbeddedIdP` with the static connector configuration. + +## Code Layout + +``` +tools/idp-migrate/ +├── config.go # migrationConfig struct, CLI flags, env vars, validation +├── main.go # CLI entry point, migration phases, config generation +├── main_test.go # 8 test functions (18 subtests) covering config, connector, URL builder, config generation +└── DEVELOPMENT.md # this file + +management/server/idp/migration/ +├── migration.go # Server interface, MigrateUsersToStaticConnectors(), PopulateUserInfo(), migrateUser(), reconcileActivityStore() +├── migration_test.go # 6 top-level tests (with subtests) using hand-written mocks +└── store.go # Store, EventStore interfaces, SchemaCheck, RequiredSchema, SchemaError types + +management/server/store/ +└── sql_store_idp_migration.go # CheckSchema(), ListUsers(), UpdateUserInfo(), UpdateUserID(), txDeferFKConstraints() on SqlStore + +management/server/activity/store/ +├── sql_store_idp_migration.go # UpdateUserID() on activity Store +└── sql_store_idp_migration_test.go # 5 subtests for activity UpdateUserID + +``` + +## Release / Distribution + +The tool is included in `.goreleaser.yaml` as the `netbird-idp-migrate` build target. Each NetBird release produces pre-built archives for Linux (amd64, arm64, arm) that are uploaded to GitHub Releases. The archive naming convention is: + +``` +netbird-idp-migrate__linux_.tar.gz +``` + +The build requires `CGO_ENABLED=1` because it links the SQLite driver used by `SqlStore`. The cross-compilation setup (CC env for arm64/arm) mirrors the `netbird-mgmt` build. + +## CLI Flags + +| Flag | Type | Default | Description | +|------|------|---------|-------------| +| `--config` | string | *(required)* | Path to management.json | +| `--datadir` | string | *(required)* | Data directory (containing store.db / events.db) | +| `--idp-seed-info` | string | *(required)* | Base64-encoded connector JSON | +| `--domain` | string | `""` | Sets both dashboard and API domain (convenience shorthand) | +| `--dashboard-domain` | string | *(required)* | Dashboard domain (for redirect URIs) | +| `--api-domain` | string | *(required)* | API domain (for Dex issuer and callback URLs) | +| `--dry-run` | bool | `false` | Preview changes without writing | +| `--force` | bool | `false` | Skip interactive confirmation prompt | +| `--skip-config` | bool | `false` | Skip config generation (DB-only migration) | +| `--skip-populate-user-info` | bool | `false` | Skip populating user info (user ID migration only) | +| `--log-level` | string | `"info"` | Log level (debug, info, warn, error) | + +## Environment Variables + +All flags can be overridden via environment variables. Env vars take precedence over flags. + +| Env Var | Overrides | +|---------|-----------| +| `NETBIRD_DOMAIN` | Sets both `--dashboard-domain` and `--api-domain` | +| `NETBIRD_API_URL` | `--api-domain` | +| `NETBIRD_DASHBOARD_URL` | `--dashboard-domain` | +| `NETBIRD_CONFIG_PATH` | `--config` | +| `NETBIRD_DATA_DIR` | `--datadir` | +| `NETBIRD_IDP_SEED_INFO` | `--idp-seed-info` | +| `NETBIRD_DRY_RUN` | `--dry-run` (set to `"true"`) | +| `NETBIRD_FORCE` | `--force` (set to `"true"`) | +| `NETBIRD_SKIP_CONFIG` | `--skip-config` (set to `"true"`) | +| `NETBIRD_SKIP_POPULATE_USER_INFO` | `--skip-populate-user-info` (set to `"true"`) | +| `NETBIRD_LOG_LEVEL` | `--log-level` | + +Resolution order: CLI flags are parsed first, then `--domain` sets both URLs, then `NETBIRD_DOMAIN` overrides both, then `NETBIRD_API_URL` / `NETBIRD_DASHBOARD_URL` override individually. After all resolution, `validateConfig()` ensures all required fields are set. + +## Migration Flow + +### Phase 0: Schema Validation + +`validateSchema()` opens the store and calls `CheckSchema(RequiredSchema)` to verify that all tables and columns required by the migration exist in the database. If anything is missing, the tool exits with a descriptive error instructing the operator to start the management server (v0.66.4+) at least once so that automatic GORM migrations create the required schema. + +### Phase 1: Populate User Info + +Unless `--skip-populate-user-info` is set, `populateUserInfoFromIDP()` runs before connector resolution: + +1. Creates an IDP manager from the existing `IdpManagerConfig` in management.json. +2. Calls `idpManager.GetAllAccounts()` to fetch email and name for all users from the external IDP. +3. Calls `migration.PopulateUserInfo()` which iterates over all store users, skipping service users and users that already have both email and name populated. For Dex-encoded user IDs, it decodes back to the original IDP ID for lookup. +4. Updates the store with any missing email/name values. + +This ensures user contact info is preserved before the ID migration makes the original IDP IDs inaccessible. + +### Phase 2: Connector Decoding + +`decodeConnectorConfig()` base64-decodes and JSON-unmarshals the connector JSON provided via `--idp-seed-info` (or `NETBIRD_IDP_SEED_INFO`). It validates that the connector ID is non-empty. There is no auto-detection or fallback — the operator must provide the full connector configuration. + +### Phase 3: DB Migration + +`migrateDB()` orchestrates the database migration: + +1. `openStores()` opens the main store (`SqlStore`) and activity store (non-fatal if missing). +2. Type-asserts both to `migration.Store` / `migration.EventStore`. +3. `previewUsers()` scans all users — counts pending vs already-migrated (using `DecodeDexUserID`). +4. `confirmPrompt()` asks for interactive confirmation (unless `--force` or `--dry-run`). +5. Calls `migration.MigrateUsersToStaticConnectors(srv, conn)`: + - **Reconciliation pass**: fixes activity store references for users already migrated in the main DB but whose events still reference old IDs (from a previous partial failure). + - **Main loop**: for each non-migrated user, calls `migrateUser()` which atomically updates the user ID in both the main store and activity store. + - **Dry-run**: logs what would happen, skips all writes. + +`SqlStore.UpdateUserID()` atomically updates the user's primary key and all foreign key references (peers, PATs, groups, policies, jobs, etc.) in a single transaction. + +### Phase 4: Config Generation + +Unless `--skip-config` is set, `generateConfig()` runs: + +1. **Read** — loads existing `management.json` as raw JSON to preserve unknown fields. + +2. **Strip** — removes keys that are no longer needed: + - `IdpManagerConfig` + - `PKCEAuthorizationFlow` + - `DeviceAuthorizationFlow` + - All `HttpConfig` fields except `CertFile` and `CertKey` + +3. **Add EmbeddedIdP** — inserts a minimal section with: + - `Enabled: true` + - `Issuer` built from `--api-domain` + `/oauth2` + - `DashboardRedirectURIs` built from `--dashboard-domain` + `/nb-auth` and `/nb-silent-auth` + - `StaticConnectors` containing the decoded connector, with `redirectURI` overridden to `--api-domain` + `/oauth2/callback` + +4. **Write** — backs up original as `management.json.bak`, writes new config. In dry-run mode, prints to stdout instead. + +## Interface Decoupling + +Migration methods (`ListUsers`, `UpdateUserID`) are **not** on the core `store.Store` or `activity.Store` interfaces. Instead, they're defined in `migration/store.go`: + +```go +type Store interface { + ListUsers(ctx context.Context) ([]*types.User, error) + UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error + UpdateUserInfo(ctx context.Context, userID, email, name string) error + CheckSchema(checks []SchemaCheck) []SchemaError +} + +type EventStore interface { + UpdateUserID(ctx context.Context, oldUserID, newUserID string) error +} +``` + +A `Server` interface wraps both stores for dependency injection: + +```go +type Server interface { + Store() Store + EventStore() EventStore // may return nil +} +``` + +The concrete `SqlStore` types already have these methods (in their respective `sql_store_idp_migration.go` files), so they satisfy the interfaces via Go's structural typing — zero changes needed on the core store interfaces. At runtime, the standalone tool type-asserts: + +```go +migStore, ok := mainStore.(migration.Store) +``` + +This keeps migration concerns completely separate from the core store contract. + +## Dex User ID Encoding + +`EncodeDexUserID(userID, connectorID)` produces a manually-encoded protobuf with two string fields, then base64-encodes the result (raw, no padding). `DecodeDexUserID` reverses this. The migration loop uses `DecodeDexUserID` to detect already-migrated users (decode succeeds → skip). + +See `idp/dex/provider.go` for the implementation. + +## Standalone Tool + +The standalone tool (`tools/idp-migrate/main.go`) is the primary migration entry point. It opens stores directly, runs schema validation, populates user info from the external IDP, migrates user IDs, and generates the new config — then exits. Configuration is handled entirely through `config.go` which parses CLI flags and environment variables. + +## Running Tests + +```bash +# Migration library +go test -v ./management/server/idp/migration/... + +# Standalone tool +go test -v ./tools/idp-migrate/... + +# Activity store migration tests +go test -v -run TestUpdateUserID ./management/server/activity/store/... + +# Build locally +go build ./tools/idp-migrate/ +``` + +## Clean Removal + +When migration tooling is no longer needed, delete: + +1. `tools/idp-migrate/` — entire directory +2. `management/server/idp/migration/` — entire directory +3. `management/server/store/sql_store_idp_migration.go` — migration methods on main SqlStore +4. `management/server/activity/store/sql_store_idp_migration.go` — migration method on activity Store +5. `management/server/activity/store/sql_store_idp_migration_test.go` — tests for the above +6. In `.goreleaser.yaml`: + - Remove the `netbird-idp-migrate` build entry + - Remove the `netbird-idp-migrate` archive entry +7. Run `go mod tidy` + +No core interfaces or mocks need editing — that's the point of the decoupling. diff --git a/tools/idp-migrate/LICENSE b/tools/idp-migrate/LICENSE new file mode 100644 index 000000000..be3f7b28e --- /dev/null +++ b/tools/idp-migrate/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/tools/idp-migrate/config.go b/tools/idp-migrate/config.go new file mode 100644 index 000000000..f4d6b9ea2 --- /dev/null +++ b/tools/idp-migrate/config.go @@ -0,0 +1,174 @@ +package main + +import ( + "flag" + "fmt" + "os" + "strconv" + + "github.com/netbirdio/netbird/util" +) + +type migrationConfig struct { + // Data + dashboardURL string + apiURL string + configPath string + dataDir string + idpSeedInfo string + + // Options + dryRun bool + force bool + skipConfig bool + skipPopulateUserInfo bool + + // Logging + logLevel string +} + +func config() (*migrationConfig, error) { + cfg, err := configFromArgs(os.Args[1:]) + if err != nil { + return nil, err + } + + if err := util.InitLog(cfg.logLevel, util.LogConsole); err != nil { + return nil, fmt.Errorf("init logger: %w", err) + } + + return cfg, nil +} + +func configFromArgs(args []string) (*migrationConfig, error) { + var cfg migrationConfig + var domain string + + fs := flag.NewFlagSet("netbird-idp-migrate", flag.ContinueOnError) + fs.StringVar(&domain, "domain", "", "domain for both dashboard and API") + fs.StringVar(&cfg.dashboardURL, "dashboard-url", "", "dashboard URL") + fs.StringVar(&cfg.apiURL, "api-url", "", "API URL") + fs.StringVar(&cfg.configPath, "config", "", "path to management.json (required)") + fs.StringVar(&cfg.dataDir, "datadir", "", "override data directory from config") + fs.StringVar(&cfg.idpSeedInfo, "idp-seed-info", "", "base64-encoded connector JSON (overrides auto-detection)") + fs.BoolVar(&cfg.dryRun, "dry-run", false, "preview changes without writing") + fs.BoolVar(&cfg.force, "force", false, "skip confirmation prompt") + fs.BoolVar(&cfg.skipConfig, "skip-config", false, "skip config generation (DB migration only)") + fs.BoolVar(&cfg.skipPopulateUserInfo, "skip-populate-user-info", false, "skip populating user info (user id migration only)") + fs.StringVar(&cfg.logLevel, "log-level", "info", "log level (debug, info, warn, error)") + + if err := fs.Parse(args); err != nil { + return nil, err + } + + applyOverrides(&cfg, domain) + + if err := validateConfig(&cfg); err != nil { + return nil, err + } + + return &cfg, nil +} + +// applyOverrides resolves domain configuration from broad to narrow sources. +// The most granular value always wins: +// +// --domain flag (broadest, only fills blanks) +// NETBIRD_DOMAIN env (overrides flags, sets both) +// --api-domain / --dashboard-domain flags (more specific than --domain) +// NETBIRD_API_URL / NETBIRD_DASHBOARD_URL env (most specific, always wins) +// +// Other env vars unconditionally override their corresponding flags. +func applyOverrides(cfg *migrationConfig, domain string) { + // --domain is a convenience shorthand: only fills in values not already + // set by the more specific --api-domain / --dashboard-domain flags. + if domain != "" { + if cfg.apiURL == "" { + cfg.apiURL = domain + } + if cfg.dashboardURL == "" { + cfg.dashboardURL = domain + } + } + + // Env vars override flags. Broad env var first, then narrow ones on top, + // so the most granular value always wins. + if val, ok := os.LookupEnv("NETBIRD_DOMAIN"); ok { + cfg.dashboardURL = val + cfg.apiURL = val + } + + if val, ok := os.LookupEnv("NETBIRD_API_URL"); ok { + cfg.apiURL = val + } + + if val, ok := os.LookupEnv("NETBIRD_DASHBOARD_URL"); ok { + cfg.dashboardURL = val + } + + if val, ok := os.LookupEnv("NETBIRD_CONFIG_PATH"); ok { + cfg.configPath = val + } + + if val, ok := os.LookupEnv("NETBIRD_DATA_DIR"); ok { + cfg.dataDir = val + } + + if val, ok := os.LookupEnv("NETBIRD_IDP_SEED_INFO"); ok { + cfg.idpSeedInfo = val + } + + // Enforce dry run if any value is provided + if sval, ok := os.LookupEnv("NETBIRD_DRY_RUN"); ok { + if val, err := strconv.ParseBool(sval); err == nil { + cfg.dryRun = val + } + } + + cfg.dryRun = parseBool("NETBIRD_DRY_RUN", cfg.dryRun) + cfg.force = parseBool("NETBIRD_FORCE", cfg.force) + cfg.skipConfig = parseBool("NETBIRD_SKIP_CONFIG", cfg.skipConfig) + cfg.skipPopulateUserInfo = parseBool("NETBIRD_SKIP_POPULATE_USER_INFO", cfg.skipPopulateUserInfo) + + if val, ok := os.LookupEnv("NETBIRD_LOG_LEVEL"); ok { + cfg.logLevel = val + } +} + +func parseBool(varName string, defaultVal bool) bool { + stringValue, ok := os.LookupEnv(varName) + if !ok { + return defaultVal + } + + boolValue, err := strconv.ParseBool(stringValue) + if err != nil { + return defaultVal + } + + return boolValue +} + +func validateConfig(cfg *migrationConfig) error { + if cfg.configPath == "" { + return fmt.Errorf("--config is required") + } + + if cfg.dataDir == "" { + return fmt.Errorf("--datadir is required") + } + + if cfg.idpSeedInfo == "" { + return fmt.Errorf("--idp-seed-info is required") + } + + if cfg.apiURL == "" { + return fmt.Errorf("--api-domain is required") + } + + if cfg.dashboardURL == "" { + return fmt.Errorf("--dashboard-domain is required") + } + + return nil +} diff --git a/tools/idp-migrate/main.go b/tools/idp-migrate/main.go new file mode 100644 index 000000000..a8cba0750 --- /dev/null +++ b/tools/idp-migrate/main.go @@ -0,0 +1,449 @@ +// Package main provides a standalone CLI tool to migrate user IDs from an +// external IdP format to the embedded Dex IdP format used by NetBird >= v0.62.0. +// +// This tool reads management.json to auto-detect the current external IdP +// configuration (issuer, clientID, clientSecret, type) and re-encodes all user +// IDs in the database to the Dex protobuf-encoded format. It works independently +// of migrate.sh and the combined server, allowing operators to migrate their +// database before switching to the combined server. +// +// Usage: +// +// netbird-idp-migrate --config /etc/netbird/management.json [--dry-run] [--force] +package main + +import ( + "bufio" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "maps" + "net/url" + "os" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/dex" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + activitystore "github.com/netbirdio/netbird/management/server/activity/store" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/idp/migration" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/crypt" +) + +// migrationServer implements migration.Server by wrapping the migration-specific interfaces. +type migrationServer struct { + store migration.Store + eventStore migration.EventStore +} + +func (s *migrationServer) Store() migration.Store { return s.store } +func (s *migrationServer) EventStore() migration.EventStore { return s.eventStore } + +func main() { + cfg, err := config() + if err != nil { + log.Fatalf("config error: %v", err) + } + + if err := run(cfg); err != nil { + log.Fatalf("migration failed: %v", err) + } + + if !cfg.dryRun { + printPostMigrationInstructions(cfg) + } +} + +func run(cfg *migrationConfig) error { + mgmtConfig := &nbconfig.Config{} + if _, err := util.ReadJsonWithEnvSub(cfg.configPath, mgmtConfig); err != nil { + return err + } + + // Validate the database schema before attempting any operations. + if err := validateSchema(mgmtConfig, cfg.dataDir); err != nil { + return err + } + + if !cfg.skipPopulateUserInfo { + err := populateUserInfoFromIDP(cfg, mgmtConfig) + if err != nil { + return fmt.Errorf("populate user info: %w", err) + } + } + + connectorConfig, err := decodeConnectorConfig(cfg.idpSeedInfo) + if err != nil { + return fmt.Errorf("resolve connector: %w", err) + } + + log.Infof( + "resolved connector: type=%s, id=%s, name=%s", + connectorConfig.Type, + connectorConfig.ID, + connectorConfig.Name, + ) + + if err := migrateDB(cfg, mgmtConfig, connectorConfig); err != nil { + return err + } + + if cfg.skipConfig { + log.Info("skipping config generation (--skip-config)") + return nil + } + + return generateConfig(cfg, connectorConfig) +} + +// validateSchema opens the store and checks that all required tables and columns +// exist. If anything is missing, it returns a descriptive error telling the user +// to upgrade their management server. +func validateSchema(mgmtConfig *nbconfig.Config, dataDir string) error { + ctx := context.Background() + migStore, migEventStore, cleanup, err := openStores(ctx, mgmtConfig, dataDir) + if err != nil { + return err + } + defer cleanup() + + errs := migStore.CheckSchema(migration.RequiredSchema) + if len(errs) > 0 { + return fmt.Errorf("%s", formatSchemaErrors(errs)) + } + + if migEventStore != nil { + eventErrs := migEventStore.CheckSchema(migration.RequiredEventSchema) + if len(eventErrs) > 0 { + return fmt.Errorf("activity store schema check failed (upgrade management server first):\n%s", formatSchemaErrors(eventErrs)) + } + } + + log.Info("database schema check passed") + return nil +} + +// formatSchemaErrors returns a user-friendly message listing all missing schema +// elements and instructing the operator to upgrade. +func formatSchemaErrors(errs []migration.SchemaError) string { + var b strings.Builder + b.WriteString("database schema is incomplete — the following tables/columns are missing:\n") + for _, e := range errs { + fmt.Fprintf(&b, " - %s\n", e.String()) + } + b.WriteString("\nPlease start the NetBird management server (v0.66.4+) at least once so that automatic database migrations create the required schema, then re-run this tool.\n") + return b.String() +} + +// populateUserInfoFromIDP creates an IDP manager from the config, fetches all +// user data (email, name) from the external IDP, and updates the store for users +// that are missing this information. +func populateUserInfoFromIDP(cfg *migrationConfig, mgmtConfig *nbconfig.Config) error { + ctx := context.Background() + + if mgmtConfig.IdpManagerConfig == nil { + return fmt.Errorf("IdpManagerConfig is not set in management.json; cannot fetch user info from IDP") + } + + idpManager, err := idp.NewManager(ctx, *mgmtConfig.IdpManagerConfig, nil) + if err != nil { + return fmt.Errorf("create IDP manager: %w", err) + } + if idpManager == nil { + return fmt.Errorf("IDP manager type is 'none' or empty; cannot fetch user info") + } + + log.Infof("created IDP manager (type: %s)", mgmtConfig.IdpManagerConfig.ManagerType) + + migStore, _, cleanup, err := openStores(ctx, mgmtConfig, cfg.dataDir) + if err != nil { + return err + } + defer cleanup() + + srv := &migrationServer{store: migStore} + return migration.PopulateUserInfo(srv, idpManager, cfg.dryRun) +} + +// openStores opens the main and activity stores, returning migration-specific interfaces. +// The caller must call the returned cleanup function to close the stores. +func openStores(ctx context.Context, cfg *nbconfig.Config, dataDir string) (migration.Store, migration.EventStore, func(), error) { + engine := cfg.StoreConfig.Engine + if engine == "" { + engine = types.SqliteStoreEngine + } + + mainStore, err := store.NewStore(ctx, engine, dataDir, nil, true) + if err != nil { + return nil, nil, nil, fmt.Errorf("open main store: %w", err) + } + + if cfg.DataStoreEncryptionKey != "" { + fieldEncrypt, err := crypt.NewFieldEncrypt(cfg.DataStoreEncryptionKey) + if err != nil { + _ = mainStore.Close(ctx) + return nil, nil, nil, fmt.Errorf("init field encryption: %w", err) + } + mainStore.SetFieldEncrypt(fieldEncrypt) + } + + migStore, ok := mainStore.(migration.Store) + if !ok { + _ = mainStore.Close(ctx) + return nil, nil, nil, fmt.Errorf("store does not support migration operations (ListUsers/UpdateUserID)") + } + + cleanup := func() { _ = mainStore.Close(ctx) } + + var migEventStore migration.EventStore + actStore, err := activitystore.NewSqlStore(ctx, dataDir, cfg.DataStoreEncryptionKey) + if err != nil { + log.Warnf("could not open activity store (events.db may not exist): %v", err) + } else { + migEventStore = actStore + prevCleanup := cleanup + cleanup = func() { _ = actStore.Close(ctx); prevCleanup() } + } + + return migStore, migEventStore, cleanup, nil +} + +// migrateDB opens the stores, previews pending users, and runs the DB migration. +func migrateDB(cfg *migrationConfig, mgmtConfig *nbconfig.Config, connectorConfig *dex.Connector) error { + ctx := context.Background() + + migStore, migEventStore, cleanup, err := openStores(ctx, mgmtConfig, cfg.dataDir) + if err != nil { + return err + } + defer cleanup() + + pending, err := previewUsers(ctx, migStore) + if err != nil { + return err + } + + if cfg.dryRun { + if err := os.Setenv("NB_IDP_MIGRATION_DRY_RUN", "true"); err != nil { + return fmt.Errorf("set dry-run env: %w", err) + } + defer os.Unsetenv("NB_IDP_MIGRATION_DRY_RUN") //nolint:errcheck + } + + if !cfg.dryRun && !cfg.force { + if !confirmPrompt(pending) { + log.Info("migration cancelled by user") + return nil + } + } + + srv := &migrationServer{store: migStore, eventStore: migEventStore} + if err := migration.MigrateUsersToStaticConnectors(srv, connectorConfig); err != nil { + return fmt.Errorf("migrate users: %w", err) + } + + if !cfg.dryRun { + log.Info("DB migration completed successfully") + } + return nil +} + +// previewUsers counts pending vs already-migrated users and logs a summary. +// Returns the number of users still needing migration. +func previewUsers(ctx context.Context, migStore migration.Store) (int, error) { + users, err := migStore.ListUsers(ctx) + if err != nil { + return 0, fmt.Errorf("list users: %w", err) + } + + var pending, alreadyMigrated int + for _, u := range users { + if _, _, decErr := dex.DecodeDexUserID(u.Id); decErr == nil { + alreadyMigrated++ + } else { + pending++ + } + } + + log.Infof("found %d total users: %d pending migration, %d already migrated", len(users), pending, alreadyMigrated) + return pending, nil +} + +// confirmPrompt asks the user for interactive confirmation. Returns true if they accept. +func confirmPrompt(pending int) bool { + log.Infof("About to migrate %d users. This cannot be easily undone. Continue? [y/N] ", pending) + reader := bufio.NewReader(os.Stdin) + answer, _ := reader.ReadString('\n') + answer = strings.TrimSpace(strings.ToLower(answer)) + return answer == "y" || answer == "yes" +} + +// decodeConnectorConfig base64-decodes and JSON-unmarshals a connector. +func decodeConnectorConfig(encoded string) (*dex.Connector, error) { + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("base64 decode: %w", err) + } + + var conn dex.Connector + if err := json.Unmarshal(decoded, &conn); err != nil { + return nil, fmt.Errorf("json unmarshal: %w", err) + } + + if conn.ID == "" { + return nil, fmt.Errorf("connector ID is empty") + } + + return &conn, nil +} + +// generateConfig reads the existing management.json as raw JSON, removes +// IdpManagerConfig, adds EmbeddedIdP, updates HttpConfig fields, and writes +// the result. In dry-run mode, it prints the new config to stdout instead. +func generateConfig(cfg *migrationConfig, connectorConfig *dex.Connector) error { + // Read existing config as raw JSON to preserve all fields + raw, err := os.ReadFile(cfg.configPath) + if err != nil { + return fmt.Errorf("read config file: %w", err) + } + + var configMap map[string]any + if err := json.Unmarshal(raw, &configMap); err != nil { + return fmt.Errorf("parse config JSON: %w", err) + } + + // Remove unused information + delete(configMap, "IdpManagerConfig") + delete(configMap, "PKCEAuthorizationFlow") + delete(configMap, "DeviceAuthorizationFlow") + + httpConfig, ok := configMap["HttpConfig"].(map[string]any) + if httpConfig != nil && ok { + certFilePath := httpConfig["CertFile"] + certKeyPath := httpConfig["CertKey"] + + delete(configMap, "HttpConfig") + + configMap["HttpConfig"] = map[string]any{ + "CertFile": certFilePath, + "CertKey": certKeyPath, + } + } + + // Ensure the connector's redirectURI points to the management server (Dex callback), + // not the external IdP. The auto-detection may have used the IdP issuer URL. + connConfig := make(map[string]any, len(connectorConfig.Config)) + maps.Copy(connConfig, connectorConfig.Config) + + redirectURI, err := buildURL(cfg.apiURL, "/oauth2/callback") + if err != nil { + return fmt.Errorf("build redirect URI: %w", err) + } + connConfig["redirectURI"] = redirectURI + + issuer, err := buildURL(cfg.apiURL, "/oauth2") + if err != nil { + return fmt.Errorf("build issuer URL: %w", err) + } + + dashboardRedirectURL, err := buildURL(cfg.dashboardURL, "/nb-auth") + if err != nil { + return fmt.Errorf("build dashboard redirect URL: %w", err) + } + + dashboardSilentRedirectURL, err := buildURL(cfg.dashboardURL, "/nb-silent-auth") + if err != nil { + return fmt.Errorf("build dashboard silent redirect URL: %w", err) + } + + // Add minimal EmbeddedIdP section + configMap["EmbeddedIdP"] = map[string]any{ + "Enabled": true, + "Issuer": issuer, + "DashboardRedirectURIs": []string{ + dashboardRedirectURL, + dashboardSilentRedirectURL, + }, + "StaticConnectors": []any{ + map[string]any{ + "type": connectorConfig.Type, + "name": connectorConfig.Name, + "id": connectorConfig.ID, + "config": connConfig, + }, + }, + } + + newJSON, err := json.MarshalIndent(configMap, "", " ") + if err != nil { + return fmt.Errorf("marshal new config: %w", err) + } + + if cfg.dryRun { + log.Info("[DRY RUN] new management.json would be:") + log.Infoln(string(newJSON)) + return nil + } + + // Backup original + backupPath := cfg.configPath + ".bak" + if err := os.WriteFile(backupPath, raw, 0o600); err != nil { + return fmt.Errorf("write backup: %w", err) + } + log.Infof("backed up original config to %s", backupPath) + + // Write new config + if err := os.WriteFile(cfg.configPath, newJSON, 0o600); err != nil { + return fmt.Errorf("write new config: %w", err) + } + log.Infof("wrote new config to %s", cfg.configPath) + + return nil +} + +func buildURL(uri, path string) (string, error) { + // Case for domain without scheme, e.g. "example.com" or "example.com:8080" + if !strings.HasPrefix(uri, "http://") && !strings.HasPrefix(uri, "https://") { + uri = "https://" + uri + } + + val, err := url.JoinPath(uri, path) + if err != nil { + return "", err + } + + return val, nil +} + +func printPostMigrationInstructions(cfg *migrationConfig) { + authAuthority, err := buildURL(cfg.apiURL, "/oauth2") + if err != nil { + authAuthority = "https:///oauth2" + } + + log.Info("Congratulations! You have successfully migrated your NetBird management server to the embedded Dex IdP.") + log.Info("Next steps:") + log.Info("1. Make sure the following environment variables are set for your dashboard server:") + log.Infof(` +AUTH_AUDIENCE=netbird-dashboard +AUTH_CLIENT_ID=netbird-dashboard +AUTH_AUTHORITY=%s +AUTH_SUPPORTED_SCOPES=openid profile email groups +AUTH_REDIRECT_URI=/nb-auth +AUTH_SILENT_REDIRECT_URI=/nb-silent-auth + `, + authAuthority, + ) + log.Info("2. Make sure you restart the dashboard & management servers to pick up the new config and environment variables.") + log.Info("eg. docker compose up -d --force-recreate management dashboard") + log.Info("3. Optional: If you have a reverse proxy configured, make sure the path `/oauth2/*` points to the management api server.") +} + +// Compile-time check that migrationServer implements migration.Server. +var _ migration.Server = (*migrationServer)(nil) diff --git a/tools/idp-migrate/main_test.go b/tools/idp-migrate/main_test.go new file mode 100644 index 000000000..75d0bd7eb --- /dev/null +++ b/tools/idp-migrate/main_test.go @@ -0,0 +1,487 @@ +package main + +import ( + "encoding/base64" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/idp/migration" +) + +// TestMigrationServerInterface is a compile-time check that migrationServer +// implements the migration.Server interface. +func TestMigrationServerInterface(t *testing.T) { + var _ migration.Server = (*migrationServer)(nil) +} + +func TestDecodeConnectorConfig(t *testing.T) { + conn := dex.Connector{ + Type: "oidc", + Name: "test", + ID: "test-id", + Config: map[string]any{ + "issuer": "https://example.com", + "clientID": "cid", + "clientSecret": "csecret", + }, + } + + data, err := json.Marshal(conn) + require.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + + result, err := decodeConnectorConfig(encoded) + require.NoError(t, err) + assert.Equal(t, "test-id", result.ID) + assert.Equal(t, "oidc", result.Type) + assert.Equal(t, "https://example.com", result.Config["issuer"]) +} + +func TestDecodeConnectorConfig_InvalidBase64(t *testing.T) { + _, err := decodeConnectorConfig("not-valid-base64!!!") + require.Error(t, err) + assert.Contains(t, err.Error(), "base64 decode") +} + +func TestDecodeConnectorConfig_InvalidJSON(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("not json")) + _, err := decodeConnectorConfig(encoded) + require.Error(t, err) + assert.Contains(t, err.Error(), "json unmarshal") +} + +func TestDecodeConnectorConfig_EmptyConnectorID(t *testing.T) { + conn := dex.Connector{ + Type: "oidc", + Name: "no-id", + ID: "", + } + data, err := json.Marshal(conn) + require.NoError(t, err) + + encoded := base64.StdEncoding.EncodeToString(data) + _, err = decodeConnectorConfig(encoded) + require.Error(t, err) + assert.Contains(t, err.Error(), "connector ID is empty") +} + +func TestValidateConfig(t *testing.T) { + valid := &migrationConfig{ + configPath: "/etc/netbird/management.json", + dataDir: "/var/lib/netbird", + idpSeedInfo: "some-base64", + apiURL: "https://api.example.com", + dashboardURL: "https://dash.example.com", + } + + t.Run("valid config", func(t *testing.T) { + require.NoError(t, validateConfig(valid)) + }) + + t.Run("missing configPath", func(t *testing.T) { + cfg := *valid + cfg.configPath = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--config") + }) + + t.Run("missing dataDir", func(t *testing.T) { + cfg := *valid + cfg.dataDir = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--datadir") + }) + + t.Run("missing idpSeedInfo", func(t *testing.T) { + cfg := *valid + cfg.idpSeedInfo = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--idp-seed-info") + }) + + t.Run("missing apiUrl", func(t *testing.T) { + cfg := *valid + cfg.apiURL = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--api-domain") + }) + + t.Run("missing dashboardUrl", func(t *testing.T) { + cfg := *valid + cfg.dashboardURL = "" + err := validateConfig(&cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "--dashboard-domain") + }) +} + +func TestConfigFromArgs_EnvVarsApplied(t *testing.T) { + t.Run("env vars fill in for missing flags", func(t *testing.T) { + t.Setenv("NETBIRD_CONFIG_PATH", "/env/management.json") + t.Setenv("NETBIRD_DATA_DIR", "/env/data") + t.Setenv("NETBIRD_IDP_SEED_INFO", "env-seed") + t.Setenv("NETBIRD_API_URL", "https://api.env.com") + t.Setenv("NETBIRD_DASHBOARD_URL", "https://dash.env.com") + + cfg, err := configFromArgs([]string{}) + require.NoError(t, err) + + assert.Equal(t, "/env/management.json", cfg.configPath) + assert.Equal(t, "/env/data", cfg.dataDir) + assert.Equal(t, "env-seed", cfg.idpSeedInfo) + assert.Equal(t, "https://api.env.com", cfg.apiURL) + assert.Equal(t, "https://dash.env.com", cfg.dashboardURL) + }) + + t.Run("flags work without env vars", func(t *testing.T) { + cfg, err := configFromArgs([]string{ + "--config", "/flag/management.json", + "--datadir", "/flag/data", + "--idp-seed-info", "flag-seed", + "--api-url", "https://api.flag.com", + "--dashboard-url", "https://dash.flag.com", + }) + require.NoError(t, err) + + assert.Equal(t, "/flag/management.json", cfg.configPath) + assert.Equal(t, "/flag/data", cfg.dataDir) + assert.Equal(t, "flag-seed", cfg.idpSeedInfo) + assert.Equal(t, "https://api.flag.com", cfg.apiURL) + assert.Equal(t, "https://dash.flag.com", cfg.dashboardURL) + }) + + t.Run("env vars override flags", func(t *testing.T) { + t.Setenv("NETBIRD_CONFIG_PATH", "/env/management.json") + t.Setenv("NETBIRD_API_URL", "https://api.env.com") + + cfg, err := configFromArgs([]string{ + "--config", "/flag/management.json", + "--datadir", "/flag/data", + "--idp-seed-info", "flag-seed", + "--api-url", "https://api.flag.com", + "--dashboard-url", "https://dash.flag.com", + }) + require.NoError(t, err) + + assert.Equal(t, "/env/management.json", cfg.configPath, "env should override flag") + assert.Equal(t, "https://api.env.com", cfg.apiURL, "env should override flag") + assert.Equal(t, "https://dash.flag.com", cfg.dashboardURL, "flag preserved when no env override") + }) + + t.Run("--domain flag with specific env var override", func(t *testing.T) { + t.Setenv("NETBIRD_API_URL", "https://api.env.com") + + cfg, err := configFromArgs([]string{ + "--domain", "both.flag.com", + "--config", "/path", + "--datadir", "/data", + "--idp-seed-info", "seed", + }) + require.NoError(t, err) + + assert.Equal(t, "https://api.env.com", cfg.apiURL, "specific env beats --domain") + assert.Equal(t, "both.flag.com", cfg.dashboardURL, "--domain fills dashboard") + }) +} + +func TestApplyOverrides_MostGranularWins(t *testing.T) { + t.Run("specific flags beat --domain", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.specific.com", + dashboardURL: "dash.specific.com", + } + applyOverrides(cfg, "broad.com") + + assert.Equal(t, "api.specific.com", cfg.apiURL) + assert.Equal(t, "dash.specific.com", cfg.dashboardURL) + }) + + t.Run("--domain fills blanks when specific flags missing", func(t *testing.T) { + cfg := &migrationConfig{} + applyOverrides(cfg, "broad.com") + + assert.Equal(t, "broad.com", cfg.apiURL) + assert.Equal(t, "broad.com", cfg.dashboardURL) + }) + + t.Run("--domain fills only the missing specific flag", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.specific.com", + } + applyOverrides(cfg, "broad.com") + + assert.Equal(t, "api.specific.com", cfg.apiURL) + assert.Equal(t, "broad.com", cfg.dashboardURL) + }) + + t.Run("NETBIRD_DOMAIN overrides flags", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.flag.com", + dashboardURL: "dash.flag.com", + } + t.Setenv("NETBIRD_DOMAIN", "env-broad.com") + + applyOverrides(cfg, "") + + assert.Equal(t, "env-broad.com", cfg.apiURL) + assert.Equal(t, "env-broad.com", cfg.dashboardURL) + }) + + t.Run("specific env vars beat NETBIRD_DOMAIN", func(t *testing.T) { + cfg := &migrationConfig{} + t.Setenv("NETBIRD_DOMAIN", "env-broad.com") + t.Setenv("NETBIRD_API_URL", "api.env-specific.com") + t.Setenv("NETBIRD_DASHBOARD_URL", "dash.env-specific.com") + + applyOverrides(cfg, "") + + assert.Equal(t, "api.env-specific.com", cfg.apiURL) + assert.Equal(t, "dash.env-specific.com", cfg.dashboardURL) + }) + + t.Run("one specific env var overrides only its field", func(t *testing.T) { + cfg := &migrationConfig{} + t.Setenv("NETBIRD_DOMAIN", "env-broad.com") + t.Setenv("NETBIRD_API_URL", "api.env-specific.com") + + applyOverrides(cfg, "") + + assert.Equal(t, "api.env-specific.com", cfg.apiURL) + assert.Equal(t, "env-broad.com", cfg.dashboardURL) + }) + + t.Run("specific env vars beat all flags combined", func(t *testing.T) { + cfg := &migrationConfig{ + apiURL: "api.flag.com", + dashboardURL: "dash.flag.com", + } + t.Setenv("NETBIRD_API_URL", "api.env.com") + t.Setenv("NETBIRD_DASHBOARD_URL", "dash.env.com") + + applyOverrides(cfg, "domain-flag.com") + + assert.Equal(t, "api.env.com", cfg.apiURL) + assert.Equal(t, "dash.env.com", cfg.dashboardURL) + }) + + t.Run("env vars override all non-domain flags", func(t *testing.T) { + cfg := &migrationConfig{ + configPath: "/flag/path", + dataDir: "/flag/data", + idpSeedInfo: "flag-seed", + dryRun: false, + force: false, + skipConfig: false, + skipPopulateUserInfo: false, + logLevel: "info", + } + t.Setenv("NETBIRD_CONFIG_PATH", "/env/path") + t.Setenv("NETBIRD_DATA_DIR", "/env/data") + t.Setenv("NETBIRD_IDP_SEED_INFO", "env-seed") + t.Setenv("NETBIRD_DRY_RUN", "true") + t.Setenv("NETBIRD_FORCE", "true") + t.Setenv("NETBIRD_SKIP_CONFIG", "true") + t.Setenv("NETBIRD_SKIP_POPULATE_USER_INFO", "true") + t.Setenv("NETBIRD_LOG_LEVEL", "debug") + + applyOverrides(cfg, "") + + assert.Equal(t, "/env/path", cfg.configPath) + assert.Equal(t, "/env/data", cfg.dataDir) + assert.Equal(t, "env-seed", cfg.idpSeedInfo) + assert.True(t, cfg.dryRun) + assert.True(t, cfg.force) + assert.True(t, cfg.skipConfig) + assert.True(t, cfg.skipPopulateUserInfo) + assert.Equal(t, "debug", cfg.logLevel) + }) + + t.Run("boolean env vars properly parse false values", func(t *testing.T) { + cfg := &migrationConfig{} + t.Setenv("NETBIRD_DRY_RUN", "false") + t.Setenv("NETBIRD_FORCE", "yes") + t.Setenv("NETBIRD_SKIP_CONFIG", "0") + + applyOverrides(cfg, "") + + assert.False(t, cfg.dryRun) + assert.False(t, cfg.force) + assert.False(t, cfg.skipConfig) + }) + + t.Run("unset env vars do not override flags", func(t *testing.T) { + cfg := &migrationConfig{ + configPath: "/flag/path", + dataDir: "/flag/data", + idpSeedInfo: "flag-seed", + dryRun: true, + logLevel: "warn", + } + + applyOverrides(cfg, "") + + assert.Equal(t, "/flag/path", cfg.configPath) + assert.Equal(t, "/flag/data", cfg.dataDir) + assert.Equal(t, "flag-seed", cfg.idpSeedInfo) + assert.True(t, cfg.dryRun) + assert.Equal(t, "warn", cfg.logLevel) + }) +} + +func TestBuildUrl(t *testing.T) { + tests := []struct { + name string + uri string + path string + expected string + }{ + {"with https scheme", "https://example.com", "/oauth2", "https://example.com/oauth2"}, + {"with http scheme", "http://example.com", "/oauth2/callback", "http://example.com/oauth2/callback"}, + {"bare domain", "example.com", "/oauth2", "https://example.com/oauth2"}, + {"domain with port", "example.com:8080", "/nb-auth", "https://example.com:8080/nb-auth"}, + {"trailing slash on uri", "https://example.com/", "/oauth2", "https://example.com/oauth2"}, + {"nested path", "https://example.com", "/oauth2/callback", "https://example.com/oauth2/callback"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url, err := buildURL(tt.uri, tt.path) + assert.NoError(t, err) + assert.Equal(t, tt.expected, url) + }) + } +} + +func TestGenerateConfig(t *testing.T) { + t.Run("generates valid config", func(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "management.json") + + originalConfig := `{ + "Datadir": "/var/lib/netbird", + "HttpConfig": { + "LetsEncryptDomain": "mgmt.example.com", + "CertFile": "/etc/ssl/cert.pem", + "CertKey": "/etc/ssl/key.pem", + "AuthIssuer": "https://zitadel.example.com/oauth2", + "AuthKeysLocation": "https://zitadel.example.com/oauth2/keys", + "OIDCConfigEndpoint": "https://zitadel.example.com/.well-known/openid-configuration", + "AuthClientID": "old-client-id", + "AuthUserIDClaim": "preferred_username" + }, + "IdpManagerConfig": { + "ManagerType": "zitadel", + "ClientConfig": { + "Issuer": "https://zitadel.example.com", + "ClientID": "zit-id", + "ClientSecret": "zit-secret" + } + } +}` + require.NoError(t, os.WriteFile(configPath, []byte(originalConfig), 0o600)) + + cfg := &migrationConfig{ + configPath: configPath, + dashboardURL: "https://mgmt.example.com", + apiURL: "https://mgmt.example.com", + } + conn := &dex.Connector{ + Type: "zitadel", + Name: "zitadel", + ID: "zitadel", + Config: map[string]any{ + "issuer": "https://zitadel.example.com", + "clientID": "zit-id", + "clientSecret": "zit-secret", + }, + } + + err := generateConfig(cfg, conn) + require.NoError(t, err) + + // Check backup was created + backupPath := configPath + ".bak" + backupData, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, originalConfig, string(backupData)) + + // Read and parse the new config + newData, err := os.ReadFile(configPath) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(newData, &result)) + + // IdpManagerConfig should be removed + _, hasOldIdp := result["IdpManagerConfig"] + assert.False(t, hasOldIdp, "IdpManagerConfig should be removed") + + _, hasPKCE := result["PKCEAuthorizationFlow"] + assert.False(t, hasPKCE, "PKCEAuthorizationFlow should be removed") + + // EmbeddedIdP should be present with minimal fields + embeddedIDP, ok := result["EmbeddedIdP"].(map[string]any) + require.True(t, ok, "EmbeddedIdP should be present") + assert.Equal(t, true, embeddedIDP["Enabled"]) + assert.Equal(t, "https://mgmt.example.com/oauth2", embeddedIDP["Issuer"]) + assert.Nil(t, embeddedIDP["LocalAuthDisabled"], "LocalAuthDisabled should not be set") + assert.Nil(t, embeddedIDP["SignKeyRefreshEnabled"], "SignKeyRefreshEnabled should not be set") + assert.Nil(t, embeddedIDP["CLIRedirectURIs"], "CLIRedirectURIs should not be set") + + // Static connector's redirectURI should use the management domain + connectors := embeddedIDP["StaticConnectors"].([]any) + require.Len(t, connectors, 1) + firstConn := connectors[0].(map[string]any) + connCfg := firstConn["config"].(map[string]any) + assert.Equal(t, "https://mgmt.example.com/oauth2/callback", connCfg["redirectURI"], + "redirectURI should be overridden to use the management domain") + + // HttpConfig should only have CertFile and CertKey + httpConfig, ok := result["HttpConfig"].(map[string]any) + require.True(t, ok, "HttpConfig should be present") + assert.Equal(t, "/etc/ssl/cert.pem", httpConfig["CertFile"]) + assert.Equal(t, "/etc/ssl/key.pem", httpConfig["CertKey"]) + assert.Nil(t, httpConfig["AuthIssuer"], "AuthIssuer should be stripped") + + // Datadir should be preserved + assert.Equal(t, "/var/lib/netbird", result["Datadir"]) + }) + + t.Run("dry run does not write files", func(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "management.json") + + originalConfig := `{"HttpConfig": {"CertFile": "", "CertKey": ""}}` + require.NoError(t, os.WriteFile(configPath, []byte(originalConfig), 0o600)) + + cfg := &migrationConfig{ + configPath: configPath, + dashboardURL: "https://mgmt.example.com", + apiURL: "https://mgmt.example.com", + dryRun: true, + } + conn := &dex.Connector{Type: "oidc", Name: "test", ID: "test"} + + err := generateConfig(cfg, conn) + require.NoError(t, err) + + // Original should be unchanged + data, err := os.ReadFile(configPath) + require.NoError(t, err) + assert.Equal(t, originalConfig, string(data)) + + // No backup should exist + _, err = os.Stat(configPath + ".bak") + assert.True(t, os.IsNotExist(err)) + }) +} diff --git a/upload-server/server/local.go b/upload-server/server/local.go index f12c472d2..f7ca50011 100644 --- a/upload-server/server/local.go +++ b/upload-server/server/local.go @@ -7,6 +7,7 @@ import ( "net/url" "os" "path/filepath" + "strings" log "github.com/sirupsen/logrus" @@ -82,15 +83,18 @@ func (l *local) getUploadURL(objectKey string) (string, error) { return newURL.String(), nil } +const maxUploadSize = 150 << 20 + func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } + r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize) body, err := io.ReadAll(r.Body) if err != nil { - http.Error(w, fmt.Sprintf("failed to read body: %v", err), http.StatusInternalServerError) + http.Error(w, "request body too large or failed to read", http.StatusRequestEntityTooLarge) return } @@ -105,20 +109,47 @@ func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) { return } - dirPath := filepath.Join(l.dir, uploadDir) - err = os.MkdirAll(dirPath, 0750) - if err != nil { + cleanBase := filepath.Clean(l.dir) + string(filepath.Separator) + + dirPath := filepath.Clean(filepath.Join(l.dir, uploadDir)) + if !strings.HasPrefix(dirPath, cleanBase) { + http.Error(w, "invalid path", http.StatusBadRequest) + log.Warnf("Path traversal attempt blocked (dir): %s", dirPath) + return + } + + filePath := filepath.Clean(filepath.Join(dirPath, uploadFile)) + if !strings.HasPrefix(filePath, cleanBase) { + http.Error(w, "invalid path", http.StatusBadRequest) + log.Warnf("Path traversal attempt blocked (file): %s", filePath) + return + } + + if err = os.MkdirAll(dirPath, 0750); err != nil { http.Error(w, "failed to create upload dir", http.StatusInternalServerError) log.Errorf("Failed to create upload dir: %v", err) return } - file := filepath.Join(dirPath, uploadFile) - if err := os.WriteFile(file, body, 0600); err != nil { - http.Error(w, "failed to write file", http.StatusInternalServerError) - log.Errorf("Failed to write file %s: %v", file, err) + flags := os.O_WRONLY | os.O_CREATE | os.O_EXCL + f, err := os.OpenFile(filePath, flags, 0600) + if err != nil { + if os.IsExist(err) { + http.Error(w, "file already exists", http.StatusConflict) + return + } + http.Error(w, "failed to create file", http.StatusInternalServerError) + log.Errorf("Failed to create file %s: %v", filePath, err) return } - log.Infof("Uploading file %s", file) + defer func() { _ = f.Close() }() + + if _, err = f.Write(body); err != nil { + http.Error(w, "failed to write file", http.StatusInternalServerError) + log.Errorf("Failed to write file %s: %v", filePath, err) + return + } + + log.Infof("Uploaded file %s", filePath) w.WriteHeader(http.StatusOK) } diff --git a/upload-server/server/local_test.go b/upload-server/server/local_test.go index bd8a87809..64b8fd228 100644 --- a/upload-server/server/local_test.go +++ b/upload-server/server/local_test.go @@ -63,3 +63,90 @@ func Test_LocalHandlePutRequest(t *testing.T) { require.NoError(t, err) require.Equal(t, fileContent, createdFileContent) } + +func Test_LocalHandlePutRequest_PathTraversal(t *testing.T) { + mockDir := t.TempDir() + mockURL := "http://localhost:8080" + t.Setenv("SERVER_URL", mockURL) + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + fileContent := []byte("malicious content") + req := httptest.NewRequest(http.MethodPut, putURLPath+"/uploads/%2e%2e%2f%2e%2e%2fetc%2fpasswd", bytes.NewReader(fileContent)) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + + _, err = os.Stat(filepath.Join(mockDir, "..", "..", "etc", "passwd")) + require.True(t, os.IsNotExist(err), "traversal file should not exist") +} + +func Test_LocalHandlePutRequest_DirTraversal(t *testing.T) { + mockDir := t.TempDir() + t.Setenv("SERVER_URL", "http://localhost:8080") + t.Setenv("STORE_DIR", mockDir) + + l := &local{url: "http://localhost:8080", dir: mockDir} + + body := bytes.NewReader([]byte("bad")) + req := httptest.NewRequest(http.MethodPut, putURLPath+"/x/evil.txt", body) + req.SetPathValue("dir", "../../../tmp") + req.SetPathValue("file", "evil.txt") + + rec := httptest.NewRecorder() + l.handlePutRequest(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + + _, err := os.Stat(filepath.Join("/tmp", "evil.txt")) + require.True(t, os.IsNotExist(err), "traversal file should not exist outside store dir") +} + +func Test_LocalHandlePutRequest_DuplicateFile(t *testing.T) { + mockDir := t.TempDir() + t.Setenv("SERVER_URL", "http://localhost:8080") + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPut, putURLPath+"/dir/dup.txt", bytes.NewReader([]byte("first"))) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + req = httptest.NewRequest(http.MethodPut, putURLPath+"/dir/dup.txt", bytes.NewReader([]byte("second"))) + rec = httptest.NewRecorder() + mux.ServeHTTP(rec, req) + require.Equal(t, http.StatusConflict, rec.Code) + + content, err := os.ReadFile(filepath.Join(mockDir, "dir", "dup.txt")) + require.NoError(t, err) + require.Equal(t, []byte("first"), content) +} + +func Test_LocalHandlePutRequest_BodyTooLarge(t *testing.T) { + mockDir := t.TempDir() + t.Setenv("SERVER_URL", "http://localhost:8080") + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + largeBody := make([]byte, maxUploadSize+1) + req := httptest.NewRequest(http.MethodPut, putURLPath+"/dir/big.txt", bytes.NewReader(largeBody)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + + _, err = os.Stat(filepath.Join(mockDir, "dir", "big.txt")) + require.True(t, os.IsNotExist(err)) +} diff --git a/upload-server/server/s3_test.go b/upload-server/server/s3_test.go index 26b0ecd09..7ab1bb379 100644 --- a/upload-server/server/s3_test.go +++ b/upload-server/server/s3_test.go @@ -5,13 +5,12 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "os" "runtime" "testing" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -20,45 +19,55 @@ import ( ) func Test_S3HandlerGetUploadURL(t *testing.T) { - if runtime.GOOS != "linux" && os.Getenv("CI") == "true" { - t.Skip("Skipping test on non-Linux and CI environment due to docker dependency") - } - if runtime.GOOS == "windows" { - t.Skip("Skipping test on Windows due to potential docker dependency") + if runtime.GOOS != "linux" { + t.Skip("Skipping test on non-Linux due to docker dependency") } - awsEndpoint := "http://127.0.0.1:4566" awsRegion := "us-east-1" ctx := context.Background() - containerRequest := testcontainers.ContainerRequest{ - Image: "localstack/localstack:s3-latest", - ExposedPorts: []string{"4566:4566/tcp"}, - WaitingFor: wait.ForLog("Ready"), - } - c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: containerRequest, - Started: true, + ContainerRequest: testcontainers.ContainerRequest{ + Image: "minio/minio:RELEASE.2025-04-22T22-12-26Z", + ExposedPorts: []string{"9000/tcp"}, + Env: map[string]string{ + "MINIO_ROOT_USER": "minioadmin", + "MINIO_ROOT_PASSWORD": "minioadmin", + }, + Cmd: []string{"server", "/data"}, + WaitingFor: wait.ForHTTP("/minio/health/ready").WithPort("9000"), + }, + Started: true, }) - if err != nil { - t.Error(err) - } - defer func(c testcontainers.Container, ctx context.Context) { + require.NoError(t, err) + t.Cleanup(func() { if err := c.Terminate(ctx); err != nil { t.Log(err) } - }(c, ctx) + }) + + mappedPort, err := c.MappedPort(ctx, "9000") + require.NoError(t, err) + + hostIP, err := c.Host(ctx) + require.NoError(t, err) + + awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port() t.Setenv("AWS_REGION", awsRegion) t.Setenv("AWS_ENDPOINT_URL", awsEndpoint) - t.Setenv("AWS_ACCESS_KEY_ID", "test") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test") + t.Setenv("AWS_ACCESS_KEY_ID", "minioadmin") + t.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin") + t.Setenv("AWS_CONFIG_FILE", "") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "") + t.Setenv("AWS_PROFILE", "") - cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint)) - if err != nil { - t.Error(err) - } + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(awsRegion), + config.WithBaseEndpoint(awsEndpoint), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("minioadmin", "minioadmin", "")), + ) + require.NoError(t, err) client := s3.NewFromConfig(cfg, func(o *s3.Options) { o.UsePathStyle = true @@ -66,19 +75,16 @@ func Test_S3HandlerGetUploadURL(t *testing.T) { }) bucketName := "test" - if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{ + _, err = client.CreateBucket(ctx, &s3.CreateBucketInput{ Bucket: &bucketName, - }); err != nil { - t.Error(err) - } + }) + require.NoError(t, err) list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) - if err != nil { - t.Error(err) - } + require.NoError(t, err) - assert.Equal(t, len(list.Buckets), 1) - assert.Equal(t, *list.Buckets[0].Name, bucketName) + require.Len(t, list.Buckets, 1) + require.Equal(t, bucketName, *list.Buckets[0].Name) t.Setenv(bucketVar, bucketName) diff --git a/util/log.go b/util/log.go index 03547024a..b1de2d999 100644 --- a/util/log.go +++ b/util/log.go @@ -43,7 +43,13 @@ func InitLogger(logger *log.Logger, logLevel string, logs ...string) error { var writers []io.Writer logFmt := os.Getenv("NB_LOG_FORMAT") + seen := make(map[string]bool, len(logs)) for _, logPath := range logs { + if seen[logPath] { + continue + } + seen[logPath] = true + switch logPath { case LogSyslog: AddSyslogHookToLogger(logger)
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}